diff --git a/.vscode/settings.json b/.vscode/settings.json index 8a485ed3..c4b341f7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -15,4 +15,4 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true -} +} \ No newline at end of file diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 53b4ab91..a547f68b 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -15,6 +15,7 @@ from dreadnode.api.models import ( AccessRefreshTokenResponse, + ContainerRegistryCredentials, DeviceCodeResponse, ExportFormat, GithubTokenResponse, @@ -22,6 +23,7 @@ Project, RawRun, RawTask, + RegistryImageDetails, Run, RunSummary, StatusFilter, @@ -710,3 +712,55 @@ def get_user_data_credentials(self) -> UserDataCredentials: """ response = self._request("GET", "/user-data/credentials") return UserDataCredentials(**response.json()) + + # Container registry access + + def get_container_registry_credentials(self) -> ContainerRegistryCredentials: + """ + Retrieves container registry credentials for Docker image access. + + Returns: + The container registry credentials object. + """ + response = self.request("POST", "/platform/registry-token") + return ContainerRegistryCredentials(**response.json()) + + def get_platform_releases( + self, tag: str, services: list[str], cli_version: str | None + ) -> RegistryImageDetails: + """ + Resolves the platform releases for the current project. + + Returns: + The resolved platform releases as a RegistryImageDetails object. 
+ """ + payload = { + "tag": tag, + "services": services, + "cli_version": cli_version, + } + try: + response = self.request("POST", "/platform/get-releases", json_data=payload) + + except RuntimeError as e: + if "403" in str(e): + raise RuntimeError("You do not have access to platform releases.") from e + + if "404" in str(e): + if "Image not found" in str(e): + raise RuntimeError("Image not found") from e + + raise RuntimeError( + f"Failed to get platform releases: {e}. The feature is likely disabled on this server" + ) from e + raise + return RegistryImageDetails(**response.json()) + + def get_platform_templates(self, tag: str) -> bytes: + """ + Retrieves the available platform templates. + """ + params = {"tag": tag} + response = self.request("GET", "/platform/templates/all", params=params) + zip_content: bytes = response.content + return zip_content diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index 61c52dda..4fef6956 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -43,6 +43,33 @@ class UserDataCredentials(BaseModel): endpoint: str | None +class ContainerRegistryCredentials(BaseModel): + registry: str + username: str + password: str + expires_at: datetime + + +class PlatformImage(BaseModel): + service: str + uri: str + digest: str + tag: str + + @property + def full_uri(self) -> str: + return f"{self.uri}@{self.digest}" + + @property + def registry(self) -> str: + return self.uri.split("/")[0] + + +class RegistryImageDetails(BaseModel): + tag: str + images: list[PlatformImage] + + # Auth diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py index 3a4030f1..777210c7 100644 --- a/dreadnode/cli/main.py +++ b/dreadnode/cli/main.py @@ -20,6 +20,7 @@ download_and_unzip_archive, validate_server_for_clone, ) +from dreadnode.cli.platform import cli as platform_cli from dreadnode.cli.profile import cli as profile_cli from dreadnode.constants import DEBUG, PLATFORM_BASE_URL from dreadnode.user_config import ServerConfig, 
UserConfig @@ -28,8 +29,9 @@ cli["--help"].group = "Meta" -cli.command(profile_cli) cli.command(agent_cli) +cli.command(platform_cli) +cli.command(profile_cli) @cli.meta.default diff --git a/dreadnode/cli/platform/__init__.py b/dreadnode/cli/platform/__init__.py new file mode 100644 index 00000000..7a874c7c --- /dev/null +++ b/dreadnode/cli/platform/__init__.py @@ -0,0 +1,3 @@ +from dreadnode.cli.platform.cli import cli + +__all__ = ["cli"] diff --git a/dreadnode/cli/platform/cli.py b/dreadnode/cli/platform/cli.py new file mode 100644 index 00000000..ac55e4d3 --- /dev/null +++ b/dreadnode/cli/platform/cli.py @@ -0,0 +1,61 @@ +import cyclopts + +from dreadnode.cli.platform.configure import configure_platform +from dreadnode.cli.platform.download import download_platform +from dreadnode.cli.platform.login import log_into_registries +from dreadnode.cli.platform.start import start_platform +from dreadnode.cli.platform.stop import stop_platform +from dreadnode.cli.platform.upgrade import upgrade_platform + +cli = cyclopts.App("platform", help="Run and manage the platform.", help_flags=[]) + + +@cli.command() +def start(tag: str | None = None) -> None: + """Start the platform. Optionally, provide a tagged version to start. + + Args: + tag: Optional image tag to use when starting the platform. + """ + start_platform(tag=tag) + + +@cli.command(name=["stop", "down"]) +def stop() -> None: + """Stop the running platform.""" + stop_platform() + + +@cli.command() +def download(tag: str | None = None) -> None: + """Download platform files for a specific tag. + + Args: + tag: Optional image tag to download. + """ + download_platform(tag=tag) + + +@cli.command() +def upgrade() -> None: + """Upgrade the platform to the latest version.""" + upgrade_platform() + + +@cli.command() +def refresh_registry_auth() -> None: + """Refresh container registry credentials for platform access. + + Used for out of band Docker management. 
+ """ + log_into_registries() + + +@cli.command() +def configure(service: str = "api") -> None: + """Configure the platform for a specific service. + + Args: + service: The name of the service to configure. + """ + configure_platform(service=service) diff --git a/dreadnode/cli/platform/configure.py b/dreadnode/cli/platform/configure.py new file mode 100644 index 00000000..bf2f5707 --- /dev/null +++ b/dreadnode/cli/platform/configure.py @@ -0,0 +1,21 @@ +from dreadnode.cli.platform.utils.env_mgmt import open_env_file +from dreadnode.cli.platform.utils.printing import print_info +from dreadnode.cli.platform.utils.versions import get_current_version, get_local_cache_dir + + +def configure_platform(service: str = "api", tag: str | None = None) -> None: + """Configure the platform for a specific service. + + Args: + service: The name of the service to configure. + """ + if not tag: + current_version = get_current_version() + tag = current_version.tag if current_version else "latest" + + print_info(f"Configuring {service} service...") + env_file = get_local_cache_dir() / tag / f".{service}.env" + open_env_file(env_file) + print_info( + f"Configuration for {service} service loaded. It will take effect the next time the service is started." 
+ ) diff --git a/dreadnode/cli/platform/constants.py b/dreadnode/cli/platform/constants.py new file mode 100644 index 00000000..8042b5c5 --- /dev/null +++ b/dreadnode/cli/platform/constants.py @@ -0,0 +1,9 @@ +import typing as t + +API_SERVICE = "api" +UI_SERVICE = "ui" +SERVICES = [API_SERVICE, UI_SERVICE] +VERSIONS_MANIFEST = "versions.json" + +SupportedArchitecture = t.Literal["amd64", "arm64"] +SUPPORTED_ARCHITECTURES: list[SupportedArchitecture] = ["amd64", "arm64"] diff --git a/dreadnode/cli/platform/docker_.py b/dreadnode/cli/platform/docker_.py new file mode 100644 index 00000000..db13cf58 --- /dev/null +++ b/dreadnode/cli/platform/docker_.py @@ -0,0 +1,266 @@ +import json +import subprocess +import time +from pathlib import Path + +from dreadnode.cli.api import create_api_client +from dreadnode.cli.platform.utils.printing import print_error, print_info, print_success + + +def _run_docker_compose_command( + args: list[str], + compose_file: Path, + timeout: int = 300, + stdin_input: str | None = None, +) -> subprocess.CompletedProcess[str]: + """Execute a docker compose command with common error handling and configuration. + + Args: + args: Additional arguments for the docker compose command. + compose_file: Path to docker-compose file. + timeout: Command timeout in seconds. + command_name: Name of the command for error messages. + stdin_input: Input to pass to stdin (for commands like docker login). + + Returns: + CompletedProcess object with command results. + + Raises: + subprocess.CalledProcessError: If command fails. + subprocess.TimeoutExpired: If command times out. + FileNotFoundError: If docker/docker-compose not found. 
+ """ + cmd = ["docker", "compose"] + + # Add compose file + cmd.extend(["-f", compose_file.as_posix()]) + + # Add the specific command arguments + cmd.extend(args) + + cmd_str = " ".join(cmd) + + try: + # Remove capture_output=True to allow real-time streaming + # stdout and stderr will go directly to the terminal + result = subprocess.run( # noqa: S603 + cmd, + check=True, + text=True, + timeout=timeout, + encoding="utf-8", + errors="replace", + input=stdin_input, + ) + + except subprocess.CalledProcessError as e: + print_error(f"{cmd_str} failed with exit code {e.returncode}") + raise + + except subprocess.TimeoutExpired: + print_error(f"{cmd_str} timed out after {timeout} seconds") + raise + + except FileNotFoundError: + print_error("Docker or docker compose not found. Please ensure Docker is installed.") + raise + + return result + + +def get_origin(ui_container: str) -> str | None: + """ + Get the ORIGIN environment variable from the UI container and return + a friendly message for the user. + + Args: + ui_container: Name of the UI container (default: dreadnode-ui). + + Returns: + str | None: Message with the origin URL, or None if not found. + """ + try: + cmd = [ + "docker", + "inspect", + "-f", + "{{range .Config.Env}}{{println .}}{{end}}", + ui_container, + ] + cp = subprocess.run( # noqa: S603 + cmd, + check=True, + text=True, + capture_output=True, + ) + + for line in cp.stdout.splitlines(): + if line.startswith("ORIGIN="): + return line.split("=", 1)[1] + + except subprocess.CalledProcessError: + return None + + return None + + +def _check_docker_creds_exist(registry: str) -> bool: + """Check if Docker credentials exist for the specified registry. + + Args: + registry: Registry hostname to check credentials for. + + Returns: + bool: True if credentials exist, False otherwise. 
+ """ + config_path = Path.home() / ".docker" / "config.json" + + if not config_path.exists(): + return False + + try: + with config_path.open() as f: + config = json.load(f) + + auths = config.get("auths", {}) + except (json.JSONDecodeError, KeyError): + return False + return registry in auths + + +def _are_docker_creds_fresh(registry: str, max_age_hours: int = 1) -> bool: + """Check if Docker credentials are fresh (recently updated). + + Args: + registry: Registry hostname to check credentials for. + max_age_hours: Maximum age in hours for credentials to be considered fresh. + + Returns: + bool: True if credentials are fresh, False otherwise. + """ + config_path = Path.home() / ".docker" / "config.json" + + if not config_path.exists(): + return False + + # Check file modification time + mtime = config_path.stat().st_mtime + age_hours = (time.time() - mtime) / 3600 + + return age_hours < max_age_hours and _check_docker_creds_exist(registry) + + +def _check_docker_installed() -> bool: + """Check if Docker is installed on the system.""" + try: + cmd = ["docker", "--version"] + subprocess.run( # noqa: S603 + cmd, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + except subprocess.CalledProcessError: + print_error("Docker is not installed. Please install Docker and try again.") + return False + + return True + + +def _check_docker_compose_installed() -> bool: + """Check if Docker Compose is installed on the system.""" + try: + cmd = ["docker", "compose", "--version"] + subprocess.run( # noqa: S603 + cmd, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError: + print_error("Docker Compose is not installed. 
Please install Docker Compose and try again.") + return False + return True + + +def docker_requirements_met() -> bool: + """Check if Docker and Docker Compose are installed.""" + return _check_docker_installed() and _check_docker_compose_installed() + + +def docker_login(registry: str) -> None: + """Log into a Docker registry using API credentials. + + Args: + registry: Registry hostname to log into. + + Raises: + subprocess.CalledProcessError: If docker login command fails. + """ + if _are_docker_creds_fresh(registry): + print_info(f"Docker credentials for {registry} are fresh. Skipping login.") + return + + print_info(f"Logging in to Docker registry: {registry} ...") + client = create_api_client() + container_registry_creds = client.get_container_registry_credentials() + + cmd = ["docker", "login", container_registry_creds.registry] + cmd.extend(["--username", container_registry_creds.username]) + cmd.extend(["--password-stdin"]) + + try: + subprocess.run( # noqa: S603 + cmd, + input=container_registry_creds.password, + text=True, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + print_success("Logged in to container registry ...") + except subprocess.CalledProcessError as e: + print_error(f"Failed to log in to container registry: {e}") + raise + + +def docker_run( + compose_file: Path, + timeout: int = 300, +) -> subprocess.CompletedProcess[str]: + """Run docker containers for the platform. + + Args: + compose_file: Path to docker-compose file. + timeout: Command timeout in seconds. + + Returns: + CompletedProcess object with command results. + + Raises: + subprocess.CalledProcessError: If command fails. + subprocess.TimeoutExpired: If command times out. + """ + + return _run_docker_compose_command(["up", "-d"], compose_file, timeout) + + +def docker_stop( + compose_file: Path, + timeout: int = 300, +) -> subprocess.CompletedProcess[str]: + """Stop docker containers for the platform. 
+ + Args: + compose_file: Path to docker-compose file. + timeout: Command timeout in seconds. + + Returns: + CompletedProcess object with command results. + + Raises: + subprocess.CalledProcessError: If command fails. + subprocess.TimeoutExpired: If command times out. + """ + return _run_docker_compose_command(["down"], compose_file, timeout) diff --git a/dreadnode/cli/platform/download.py b/dreadnode/cli/platform/download.py new file mode 100644 index 00000000..d394098b --- /dev/null +++ b/dreadnode/cli/platform/download.py @@ -0,0 +1,144 @@ +import io +import json +import zipfile + +from dreadnode.api.models import RegistryImageDetails +from dreadnode.cli.api import create_api_client +from dreadnode.cli.platform.constants import SERVICES, VERSIONS_MANIFEST +from dreadnode.cli.platform.schemas import LocalVersionSchema +from dreadnode.cli.platform.utils.env_mgmt import ( + create_default_env_files, +) +from dreadnode.cli.platform.utils.printing import ( + print_error, + print_info, + print_success, + print_warning, +) +from dreadnode.cli.platform.utils.versions import ( + confirm_with_context, + create_local_latest_tag, + get_available_local_versions, + get_cli_version, + get_local_cache_dir, +) + + +def _resolve_latest(tag: str) -> str: + """Resolve 'latest' tag to actual version tag from API. + + Args: + tag: Version tag that contains 'latest'. + + Returns: + str: Resolved actual version tag. + """ + api_client = create_api_client() + release_info = api_client.get_platform_releases( + tag, services=SERVICES, cli_version=get_cli_version() + ) + return release_info.tag + + +def _create_local_version_file_structure( + tag: str, release_info: RegistryImageDetails +) -> LocalVersionSchema: + """Create local file structure and update manifest for a new version. + + Args: + tag: Version tag to create structure for. + release_info: Registry image details from API. + + Returns: + LocalVersionSchema: Created local version schema. 
+ """ + available_local_versions = get_available_local_versions() + + # Create a new local version schema + local_cache_dir = get_local_cache_dir() + new_version = LocalVersionSchema( + **release_info.model_dump(), + local_path=local_cache_dir / tag, + current=False, + ) + + # Add the new version to the available local versions + available_local_versions.versions.append(new_version) + + # sort the manifest by semver, newest first + available_local_versions.versions.sort(key=lambda v: v.tag, reverse=True) + + # update the manifest file + manifest_path = local_cache_dir / VERSIONS_MANIFEST + with manifest_path.open(encoding="utf-8", mode="w") as f: + json.dump(available_local_versions.model_dump(), f, indent=2) + + print_success(f"Updated versions manifest at {manifest_path} with {new_version.tag}") + + if new_version.local_path.exists(): + print_warning(f"Version {tag} already exists locally.") + if not confirm_with_context("overwrite it?"): + print_error("Aborting download.") + return new_version + + # create the directory + new_version.local_path.mkdir(parents=True, exist_ok=True) + + return new_version + + +def _download_version_files(tag: str) -> LocalVersionSchema: + """Download platform version files from API and extract locally. + + Args: + tag: Version tag to download. + + Returns: + LocalVersionSchema: Downloaded local version schema. 
+ """ + api_client = create_api_client() + release_info = api_client.get_platform_releases( + tag, services=SERVICES, cli_version=get_cli_version() + ) + zip_content = api_client.get_platform_templates(tag) + + new_local_version = _create_local_version_file_structure(release_info.tag, release_info) + + with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file: + zip_file.extractall(new_local_version.local_path) + print_success(f"Downloaded version {tag} to {new_local_version.local_path}") + + create_default_env_files(new_local_version) + return new_local_version + + +def download_platform(tag: str | None = None) -> LocalVersionSchema: + """Download platform version if not already available locally. + + Args: + tag: Version tag to download (supports 'latest'). + + Returns: + LocalVersionSchema: Local version schema for the downloaded/existing version. + """ + if not tag or tag == "latest": + # all remote images are tagged with architecture + tag = create_local_latest_tag() + + if "latest" in tag: + tag = _resolve_latest(tag) + + # get what's available + available_local_versions = get_available_local_versions() + + # if there are versions available + if available_local_versions.versions: + for available_local_version in available_local_versions.versions: + if tag == available_local_version.tag: + print_success( + f"Version {tag} is already downloaded at {available_local_version.local_path}" + ) + return available_local_version + + print_info(f"Version {tag} is not available locally. 
Will download it.") + return _download_version_files(tag) diff --git a/dreadnode/cli/platform/login.py b/dreadnode/cli/platform/login.py new file mode 100644 index 00000000..ab020199 --- /dev/null +++ b/dreadnode/cli/platform/login.py @@ -0,0 +1,20 @@ +from dreadnode.cli.platform.docker_ import docker_login +from dreadnode.cli.platform.utils.printing import print_info +from dreadnode.cli.platform.utils.versions import get_current_version + + +def log_into_registries() -> None: + """Log into all Docker registries for the current platform version. + + Iterates through all images in the current version and logs into their + respective registries. If no current version is set, displays an error message. + """ + current_version = get_current_version() + if not current_version: + print_info("There are no registries configured. Run `dreadnode platform start` to start.") + return + registries_attempted = set() + for image in current_version.images: + if image.registry not in registries_attempted: + docker_login(image.registry) + registries_attempted.add(image.registry) diff --git a/dreadnode/cli/platform/schemas.py b/dreadnode/cli/platform/schemas.py new file mode 100644 index 00000000..5029452f --- /dev/null +++ b/dreadnode/cli/platform/schemas.py @@ -0,0 +1,83 @@ +from pathlib import Path + +from pydantic import BaseModel, field_serializer + +from dreadnode.api.models import RegistryImageDetails +from dreadnode.cli.platform.constants import API_SERVICE, UI_SERVICE + + +class LocalVersionSchema(RegistryImageDetails): + local_path: Path + current: bool + + @field_serializer("local_path") + def serialize_path(self, path: Path) -> str: + """Serialize Path object to absolute path string. + + Args: + path: Path object to serialize. + + Returns: + str: Absolute path as string. 
+ """ + return str(path.resolve()) # Convert to absolute path string + + @property + def compose_file(self) -> Path: + return self.local_path / "docker-compose.yaml" + + @property + def api_env_file(self) -> Path: + return self.local_path / f".{API_SERVICE}.env" + + @property + def api_example_env_file(self) -> Path: + return self.local_path / f".{API_SERVICE}.example.env" + + @property + def ui_env_file(self) -> Path: + return self.local_path / f".{UI_SERVICE}.env" + + @property + def ui_example_env_file(self) -> Path: + return self.local_path / f".{UI_SERVICE}.example.env" + + def get_env_path_by_service(self, service: str) -> Path: + """Get environment file path for a specific service. + + Args: + service: Service name to get env path for. + + Returns: + Path: Path to the service's environment file. + + Raises: + ValueError: If service is not recognized. + """ + if service == API_SERVICE: + return self.api_env_file + if service == UI_SERVICE: + return self.ui_env_file + raise ValueError(f"Unknown service: {service}") + + def get_example_env_path_by_service(self, service: str) -> Path: + """Get example environment file path for a specific service. + + Args: + service: Service name to get example env path for. + + Returns: + Path: Path to the service's example environment file. + + Raises: + ValueError: If service is not recognized. 
+ """ + if service == API_SERVICE: + return self.api_example_env_file + if service == UI_SERVICE: + return self.ui_example_env_file + raise ValueError(f"Unknown service: {service}") + + +class LocalVersionsSchema(BaseModel): + versions: list[LocalVersionSchema] diff --git a/dreadnode/cli/platform/start.py b/dreadnode/cli/platform/start.py new file mode 100644 index 00000000..93654a29 --- /dev/null +++ b/dreadnode/cli/platform/start.py @@ -0,0 +1,54 @@ +from dreadnode.cli.platform.docker_ import ( + docker_login, + docker_requirements_met, + docker_run, + get_origin, +) +from dreadnode.cli.platform.download import download_platform +from dreadnode.cli.platform.utils.env_mgmt import generate_env_file +from dreadnode.cli.platform.utils.printing import print_error, print_info, print_success +from dreadnode.cli.platform.utils.versions import ( + create_local_latest_tag, + get_current_version, + mark_current_version, +) + + +def start_platform(tag: str | None = None) -> None: + """Start the platform with the specified or current version. + + Args: + tag: Optional image tag to use. If not provided, uses the current + version or downloads the latest available version. 
+ """ + if not docker_requirements_met(): + print_error("Docker and Docker Compose must be installed to start the platform.") + return + + if tag: + selected_version = download_platform(tag) + mark_current_version(selected_version) + elif current_version := get_current_version(): + selected_version = current_version + # no need to mark + else: + latest_tag = create_local_latest_tag() + selected_version = download_platform(latest_tag) + mark_current_version(selected_version) + + registries_attempted = set() + for image in selected_version.images: + if image.registry not in registries_attempted: + docker_login(image.registry) + registries_attempted.add(image.registry) + generate_env_file(selected_version) + print_info(f"Starting platform: {selected_version.tag}") + docker_run(selected_version.compose_file) + print_success(f"Platform {selected_version.tag} started successfully.") + origin = get_origin("dreadnode-ui") + if origin: + print_info("You can access the app at the following URLs:") + print_info(f" - {origin}") + else: + print_info(" - Unable to determine the app URL.") + print_info("Please check the container logs for more information.") diff --git a/dreadnode/cli/platform/stop.py b/dreadnode/cli/platform/stop.py new file mode 100644 index 00000000..62e29397 --- /dev/null +++ b/dreadnode/cli/platform/stop.py @@ -0,0 +1,21 @@ +from dreadnode.cli.platform.docker_ import docker_stop +from dreadnode.cli.platform.utils.env_mgmt import remove_generated_env_file +from dreadnode.cli.platform.utils.printing import print_error, print_success +from dreadnode.cli.platform.utils.versions import ( + get_current_version, +) + + +def stop_platform() -> None: + """Stop the currently running platform. + + Uses the current version's compose file to stop all platform containers + via docker compose down. + """ + current_version = get_current_version() + if not current_version: + print_error("No current version found. 
Nothing to stop.") + return + docker_stop(current_version.compose_file) + print_success("Platform stopped successfully.") + remove_generated_env_file(current_version) diff --git a/dreadnode/cli/platform/upgrade.py b/dreadnode/cli/platform/upgrade.py new file mode 100644 index 00000000..4c76d8c2 --- /dev/null +++ b/dreadnode/cli/platform/upgrade.py @@ -0,0 +1,81 @@ +from dreadnode.cli.platform.constants import SERVICES +from dreadnode.cli.platform.docker_ import docker_stop +from dreadnode.cli.platform.download import download_platform +from dreadnode.cli.platform.start import start_platform +from dreadnode.cli.platform.utils.env_mgmt import ( + merge_env_files_content, +) +from dreadnode.cli.platform.utils.printing import print_error, print_info +from dreadnode.cli.platform.utils.versions import ( + confirm_with_context, + create_local_latest_tag, + get_current_version, + get_semver_from_tag, + mark_current_version, + newer_remote_version, +) + + +def upgrade_platform() -> None: + """Upgrade the platform to the latest available version. + + Downloads the latest version, compares it with the current version, + and performs the upgrade if a newer version is available. Optionally + merges configuration files from the current version to the new version. + Stops the current platform and starts the upgraded version. + """ + latest_tag = create_local_latest_tag() + + latest_version = download_platform(latest_tag) + current_local_version = get_current_version() + + if not current_local_version: + print_error( + "No current platform version found. Run `dreadnode platform start` to start the latest version." 
+ ) + return + + current_semver = get_semver_from_tag(current_local_version.tag) + remote_semver = get_semver_from_tag(latest_version.tag) + + if not newer_remote_version(current_semver, remote_semver): + print_info(f"You are using the latest ({current_semver}) version of the platform.") + return + + if not confirm_with_context( + f"Are you sure you want to upgrade from {current_local_version.tag} to {latest_version.tag}?" + ): + print_error("Aborting upgrade.") + return + + if confirm_with_context( + f"Would you like to attempt to merge configuration files from {current_local_version.tag} to {latest_version.tag}?" + ): + for service in SERVICES: + original_env_file = current_local_version.get_example_env_path_by_service(service) + with original_env_file.open() as f: + original_env_content = f.read() + current_env_file = current_local_version.get_env_path_by_service(service) + with current_env_file.open() as f: + current_env_content = f.read() + new_env_file = latest_version.get_env_path_by_service(service) + with new_env_file.open() as f: + new_env_content = f.read() + merged_env_content = merge_env_files_content( + original_env_content, current_env_content, new_env_content + ) + with new_env_file.open("w") as f: + f.write(merged_env_content) + print_info(f" - Merged .env file for {service}: {merged_env_content}") + + print_info(".env files merged.") + + else: + print_info("Skipping .env file merge.") + + print_info(f"Stopping current platform version {current_local_version.tag}...") + docker_stop(current_local_version.compose_file) + print_info(f"Current platform version {current_local_version.tag} stopped.") + + mark_current_version(latest_version) + start_platform() diff --git a/dreadnode/cli/platform/utils/__init__.py b/dreadnode/cli/platform/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dreadnode/cli/platform/utils/env_mgmt.py b/dreadnode/cli/platform/utils/env_mgmt.py new file mode 100644 index 00000000..30a93a75 --- /dev/null 
+++ b/dreadnode/cli/platform/utils/env_mgmt.py @@ -0,0 +1,429 @@ +import subprocess +import sys +import typing as t +from pathlib import Path + +from dreadnode.cli.platform.constants import ( + SERVICES, +) +from dreadnode.cli.platform.schemas import LocalVersionSchema +from dreadnode.cli.platform.utils.printing import print_error, print_info + +LineTypes = t.Literal["variable", "comment", "empty"] + + +class _EnvLine(t.NamedTuple): + """Represents a line in an .env file with its type and content.""" + + line_type: LineTypes + key: str | None = None + value: str = "" + original_line: str = "" + + +def _parse_env_lines(content: str) -> list[_EnvLine]: + """ + Parse .env file content into structured lines preserving all formatting. + + Args: + content (str): The content of the .env file + + Returns: + List[EnvLine]: List of parsed lines with their types + """ + lines = [] + + for line in content.split("\n"): + stripped = line.strip() + + if not stripped: + # Empty line + lines.append(_EnvLine("empty", original_line=line)) + elif stripped.startswith("#"): + # Comment line + lines.append(_EnvLine("comment", original_line=line)) + elif "=" in stripped: + # Variable line + key, value = stripped.split("=", 1) + lines.append(_EnvLine("variable", key.strip(), value.strip(), line)) + else: + # Treat as comment/invalid line to preserve it + lines.append(_EnvLine("comment", original_line=line)) + + return lines + + +def _extract_variables(lines: list[_EnvLine]) -> dict[str, str]: + """Extract just the variables from parsed lines. + + Args: + lines: List of parsed environment file lines. + + Returns: + dict[str, str]: Dictionary mapping variable names to their values. 
+ """ + return { + line.key: line.value + for line in lines + if line.line_type == "variable" and line.key is not None + } + + +def _merge_env_files( + original_remote_content: str, + current_local_content: str, + updated_remote_content: str, +) -> dict[str, str]: + """ + Merge .env files with the following logic: + 1. Local changes (updates/additions) take precedence over remote defaults + 2. Remote removals remove the key from local (unless locally modified) + 3. Remote additions are added to local + 4. Local additions are preserved + + Args: + original_remote_content (str): Original remote .env content (baseline) + current_local_content (str): Current local .env content (with local changes) + updated_remote_content (str): Updated remote .env content (new remote state) + + Returns: + Dict[str, str]: Merged variables dictionary + """ + # Extract variables from each file + original_remote = _extract_variables(_parse_env_lines(original_remote_content)) + current_local = _extract_variables(_parse_env_lines(current_local_content)) + updated_remote = _extract_variables(_parse_env_lines(updated_remote_content)) + + # Result dictionary to build the merged content + merged = {} + + # Step 1: Start with current local content (preserves local changes and additions) + merged.update(current_local) + + # Step 2: Add new keys from updated remote (remote additions) + merged.update( + { + key: value + for key, value in updated_remote.items() + if key not in original_remote + and key not in current_local # New remote addition not already locally added + } + ) + + # Step 3: Handle remote removals + for key in original_remote: + # Only remove if the key was removed in remote and the local value matches the original remote value + if ( + key not in updated_remote + and key in current_local + and current_local[key] == original_remote[key] + ): + merged.pop(key, None) + + # Step 4: Update values for keys that exist in both updated remote and weren't locally modified + merged.update( + 
{ + key: remote_value + for key, remote_value in updated_remote.items() + if ( + key in original_remote + and key in current_local + and current_local[key] == original_remote[key] + ) + } + ) + + return merged + + +def _find_insertion_points( + base_lines: list[_EnvLine], remote_lines: list[_EnvLine], new_vars: dict[str, str] +) -> dict[str, int]: + """Find the best insertion points for new variables based on remote file structure. + + Args: + base_lines: Lines from local file. + remote_lines: Lines from remote file. + new_vars: New variables to place. + + Returns: + dict[str, int]: Dict mapping variable names to insertion indices in base_lines. + """ + insertion_points = {} + + # Build a map of variable positions in the remote file + remote_var_positions = {} + remote_var_context = {} + + for i, line in enumerate(remote_lines): + if line.line_type == "variable": + remote_var_positions[line.key] = i + # Capture context (preceding comment/section) + context_lines: list[str] = [] + j = i - 1 + while j >= 0 and remote_lines[j].line_type in ["comment", "empty"]: + if remote_lines[j].line_type == "comment": + context_lines.insert(0, remote_lines[j].original_line) + break # Stop at first comment (section header) + j -= 1 + remote_var_context[line.key] = context_lines + + # Build a map of variable positions in the local file + local_var_positions = {} + for i, line in enumerate(base_lines): + if line.line_type == "variable": + local_var_positions[line.key] = i + + # For each new variable, find the best insertion point + for new_var in new_vars: + if new_var not in remote_var_positions: + # Variable not in remote, place at end + insertion_points[new_var] = len(base_lines) + continue + + remote_pos = remote_var_positions[new_var] + + # Find variables that appear before this one in the remote file + preceding_vars = [ + var + for var, pos in remote_var_positions.items() + if pos < remote_pos and var in local_var_positions + ] + + # Find variables that appear after this one 
in the remote file + following_vars = [ + var + for var, pos in remote_var_positions.items() + if pos > remote_pos and var in local_var_positions + ] + + if preceding_vars: + # Place after the last preceding variable that exists locally + last_preceding = max(preceding_vars, key=lambda v: local_var_positions[v]) + insertion_points[new_var] = local_var_positions[last_preceding] + 1 + elif following_vars: + # Place before the first following variable that exists locally + first_following = min(following_vars, key=lambda v: local_var_positions[v]) + insertion_points[new_var] = local_var_positions[first_following] + else: + # No context, place at end + insertion_points[new_var] = len(base_lines) + + return insertion_points + + +def _reconstruct_env_content( # noqa: PLR0912 + base_lines: list[_EnvLine], merged_vars: dict[str, str], updated_remote_lines: list[_EnvLine] +) -> str: + """Reconstruct .env content preserving structure from base while applying merged variables. + + Args: + base_lines: Parsed lines from the local file (for structure). + merged_vars: Dictionary of merged variables. + updated_remote_lines: Parsed lines from updated remote (for new additions). + + Returns: + str: Reconstructed .env content. 
+ """ + result_lines: list[str] = [] + processed_keys = set() + + # Identify new variables that need to be inserted + existing_keys = {line.key for line in base_lines if line.line_type == "variable"} + new_vars = {k: v for k, v in merged_vars.items() if k not in existing_keys} + + # Find optimal insertion points for new variables + insertion_points = _find_insertion_points(base_lines, updated_remote_lines, new_vars) + + # Group new variables by insertion point + vars_by_insertion: dict[int, list[str]] = {} + for var, insertion_idx in insertion_points.items(): + if insertion_idx not in vars_by_insertion: + vars_by_insertion[insertion_idx] = [] + vars_by_insertion[insertion_idx].append(var) + + # Process base structure, inserting new variables at appropriate points + for i, line in enumerate(base_lines): + # Insert new variables that belong before this line + if i in vars_by_insertion: + # Add context comments if this is a new section + added_section_break = False + for var in vars_by_insertion[i]: + # Check if we need a section break (empty line before new variables) + if not added_section_break and result_lines and result_lines[-1].strip(): + # Look for context from remote file + remote_context = None + for remote_line in updated_remote_lines: + if remote_line.line_type == "variable" and remote_line.key == var: + # Find preceding comment in remote file + remote_idx = updated_remote_lines.index(remote_line) + for j in range(remote_idx - 1, -1, -1): + if updated_remote_lines[j].line_type == "comment": + remote_context = updated_remote_lines[j].original_line + break + if updated_remote_lines[j].line_type == "variable": + break + break + + # Add section break with context comment if available + if remote_context: + result_lines.append("") # Empty line + result_lines.append(remote_context) # Section comment + elif i > 0 and base_lines[i - 1].line_type == "variable": + result_lines.append("") # Just empty line for separation + + added_section_break = True + + # Add the 
new variable + result_lines.append(f"{var}={new_vars[var]}") + processed_keys.add(var) + + # Process the current line + if line.line_type == "variable": + if line.key in merged_vars: + # Keep the variable, potentially with updated value + new_value = merged_vars[line.key] + if line.value == new_value: + # Value unchanged, keep original formatting + result_lines.append(line.original_line) + else: + # Value changed, reconstruct line maintaining original key formatting + original_key_part = line.original_line.split("=")[0] + result_lines.append(f"{original_key_part}={new_value}") + processed_keys.add(line.key) + # If key not in merged_vars, it was removed, so skip it + else: + # Preserve comments and empty lines + result_lines.append(line.original_line) + + # Handle any remaining new variables (those that should go at the very end) + end_insertion_idx = len(base_lines) + if end_insertion_idx in vars_by_insertion: + if result_lines and result_lines[-1].strip(): # Add separator if needed + result_lines.append("") + result_lines.extend( + [ + f"{var}={new_vars[var]}" + for var in vars_by_insertion[end_insertion_idx] + if var not in processed_keys + ] + ) + + # Join lines + return "\n".join(result_lines) + + +def merge_env_files_content( + original_remote_content: str, current_local_content: str, updated_remote_content: str +) -> str: + """Main function to merge .env file contents preserving formatting and structure. + + Args: + original_remote_content: Original remote .env content. + current_local_content: Current local .env content. + updated_remote_content: Updated remote .env content. + + Returns: + str: Merged .env file content with preserved formatting. 
+ """ + # Get the merged variables using the original logic + merged_vars = _merge_env_files( + original_remote_content, current_local_content, updated_remote_content + ) + + # Parse the local file structure to preserve its formatting + local_lines = _parse_env_lines(current_local_content) + updated_remote_lines = _parse_env_lines(updated_remote_content) + + # Reconstruct content preserving local structure but with merged variables + return _reconstruct_env_content(local_lines, merged_vars, updated_remote_lines) + + +def create_default_env_files(current_version: LocalVersionSchema) -> None: + """Create default environment files for all services in the current version. + + Copies sample environment files to actual environment files if they don't exist, + and creates a combined .env file from API and UI environment files. + + Args: + current_version: The current local version schema containing service information. + + Raises: + RuntimeError: If sample environment files are not found or .env file creation fails. + """ + for service in SERVICES: + for image in current_version.images: + if image.service == service: + env_file_path = current_version.get_env_path_by_service(service) + if not env_file_path.exists(): + # copy the sample + sample_env_file_path = current_version.get_example_env_path_by_service(service) + if sample_env_file_path.exists(): + print_info(f"Copying {sample_env_file_path} to {env_file_path}...") + env_file_path.write_text(sample_env_file_path.read_text()) + else: + print_error( + f"Sample environment file for {service} not found at {sample_env_file_path}." + ) + raise RuntimeError( + f"Sample environment file for {service} not found. Cannot configure {service}." + ) + + +def generate_env_file(current_version: LocalVersionSchema) -> None: + """Generate a .env file for the current version by concatenating API and UI environment files. + + This file is used by Docker Compose. 
+ + Args: + current_version: The current local version schema containing service information. + + Returns: + None + + Raises: + RuntimeError: If .env file creation fails. + """ + env_file = current_version.local_path / ".env" + + api_env_file = current_version.api_env_file + ui_env_file = current_version.ui_env_file + + if api_env_file.exists() and ui_env_file.exists(): + print_info(f"Concatenating {api_env_file} and {ui_env_file} into {env_file}...") + with env_file.open("w") as outfile: + outfile.write("# WARNING: This file is auto-generated. Do not edit directly.\n") + for fname in (api_env_file, ui_env_file): + with fname.open() as infile: + outfile.write(infile.read()) + else: + print_error(f"One or both environment files not found: {api_env_file}, {ui_env_file}.") + raise RuntimeError("Failed to create .env file.") + + +def remove_generated_env_file(current_version: LocalVersionSchema) -> None: + """Remove the generated .env file for the current version. + + Args: + current_version: The current local version schema containing service information. + """ + env_file = current_version.local_path / ".env" + if env_file.exists(): + env_file.unlink() + + +def open_env_file(filename: Path) -> None: + """Open the specified environment file in the default editor. + + Args: + filename: The path to the environment file to open. 
+ """ + if sys.platform == "darwin": + cmd = ["open", "-t", filename.as_posix()] + else: + cmd = ["xdg-open", filename.as_posix()] + try: + subprocess.run(cmd, check=False) # noqa: S603 + print_info("Opened environment file.") + except subprocess.CalledProcessError as e: + print_error(f"Failed to open environment file: {e}") diff --git a/dreadnode/cli/platform/utils/printing.py b/dreadnode/cli/platform/utils/printing.py new file mode 100644 index 00000000..2691e44e --- /dev/null +++ b/dreadnode/cli/platform/utils/printing.py @@ -0,0 +1,43 @@ +import sys + +import rich + + +def print_success(message: str, prefix: str | None = None) -> None: + """Print success message in green""" + prefix = prefix or "✓" + rich.print(f"[bold green]{prefix}[/] [green]{message}[/]") + + +def print_error(message: str, prefix: str | None = None) -> None: + """Print error message in red""" + prefix = prefix or "✗" + rich.print(f"[bold red]{prefix}[/] [red]{message}[/]", file=sys.stderr) + + +def print_warning(message: str, prefix: str | None = None) -> None: + """Print warning message in yellow""" + prefix = prefix or "⚠" + rich.print(f"[bold yellow]{prefix}[/] [yellow]{message}[/]") + + +def print_info(message: str, prefix: str | None = None) -> None: + """Print info message in blue""" + prefix = prefix or "i" + rich.print(f"[bold blue]{prefix}[/] [blue]{message}[/]") + + +def print_debug(message: str, prefix: str | None = None) -> None: + """Print debug message in dim gray""" + prefix = prefix or "🐛" + rich.print(f"[dim]{prefix}[/] [dim]{message}[/]") + + +def print_heading(message: str) -> None: + """Print section heading""" + rich.print(f"\n[bold underline]{message}[/]\n") + + +def print_muted(message: str) -> None: + """Print muted text""" + rich.print(f"[dim]{message}[/]") diff --git a/dreadnode/cli/platform/utils/versions.py b/dreadnode/cli/platform/utils/versions.py new file mode 100644 index 00000000..6c1f3b65 --- /dev/null +++ b/dreadnode/cli/platform/utils/versions.py @@ -0,0 
+1,165 @@ +import importlib.metadata +import json +import platform +from pathlib import Path + +from packaging.version import Version +from rich.prompt import Confirm + +from dreadnode.cli.platform.constants import ( + SUPPORTED_ARCHITECTURES, + VERSIONS_MANIFEST, + SupportedArchitecture, +) +from dreadnode.cli.platform.schemas import LocalVersionSchema, LocalVersionsSchema +from dreadnode.constants import DEFAULT_LOCAL_STORAGE_DIR + + +def _get_local_arch() -> SupportedArchitecture: + """Get the local machine architecture in supported format. + + Returns: + SupportedArchitecture: The architecture as either "amd64" or "arm64". + + Raises: + ValueError: If the local architecture is not supported. + """ + arch = platform.machine() + + if arch in ["x86_64", "AMD64"]: + return "amd64" + if arch in ["arm64", "aarch64", "ARM64"]: + return "arm64" + raise ValueError(f"Unsupported architecture: {arch}") + + +def get_local_cache_dir() -> Path: + """Get the local cache directory path for dreadnode platform files. + + Returns: + Path: Path to the local cache directory (~/.dreadnode/platform). + """ + return DEFAULT_LOCAL_STORAGE_DIR / "platform" + + +def get_cli_version() -> str: + """Get the version of the dreadnode CLI package. + + Returns: + str: The installed version string of the dreadnode package. + """ + return importlib.metadata.version("dreadnode") + + +def confirm_with_context(action: str) -> bool: + """Prompt the user for confirmation with a formatted action message. + + Args: + action: The action description to display in the confirmation prompt. + + Returns: + bool: True if the user confirms, False otherwise. Defaults to False. + """ + return Confirm.ask(f"[bold blue]{action}[/bold blue]", default=False) + + +def get_available_local_versions() -> LocalVersionsSchema: + """Get all available local platform versions from the manifest file. + + Creates the manifest file with an empty schema if it doesn't exist. 
+ + Returns: + LocalVersionsSchema: Schema containing all available local platform versions. + """ + try: + local_cache_dir = get_local_cache_dir() + manifest_path = local_cache_dir / VERSIONS_MANIFEST + with manifest_path.open(encoding="utf-8") as f: + versions_manifest_data = json.load(f) + return LocalVersionsSchema(**versions_manifest_data) + except FileNotFoundError: + # create the file + local_cache_dir = get_local_cache_dir() + manifest_path = local_cache_dir / VERSIONS_MANIFEST + manifest_path.parent.mkdir(parents=True, exist_ok=True) + blank_schema = LocalVersionsSchema(versions=[]) + with manifest_path.open(encoding="utf-8", mode="w") as f: + json.dump(blank_schema.model_dump(), f) + return blank_schema + + +def get_current_version() -> LocalVersionSchema | None: + """Get the currently active local platform version. + + Returns: + LocalVersionSchema | None: The current version schema if one is marked as current, + None otherwise. + """ + available_local_versions = get_available_local_versions() + if not available_local_versions.versions: + return None + for version in available_local_versions.versions: + if version.current: + return version + return None + + +def mark_current_version(current_version: LocalVersionSchema) -> None: + """Mark a specific version as the current active version. + + Updates the versions manifest to mark the specified version as current + and all others as not current. + + Args: + current_version: The version to mark as current. 
+ """ + available_local_versions = get_available_local_versions() + for available_version in available_local_versions.versions: + if available_version.tag == current_version.tag: + available_version.current = True + else: + available_version.current = False + + local_cache_dir = get_local_cache_dir() + manifest_path = local_cache_dir / VERSIONS_MANIFEST + with manifest_path.open(encoding="utf-8", mode="w") as f: + json.dump(available_local_versions.model_dump(), f, indent=2) + + +def create_local_latest_tag() -> str: + """Create a latest tag string for the local architecture. + + Returns: + str: A tag in the format "latest-{arch}" where arch is the local architecture. + """ + arch = _get_local_arch() + return f"latest-{arch}" + + +def get_semver_from_tag(tag: str) -> str: + """Extract semantic version from a tag by removing architecture suffix. + + Args: + tag: The tag string that may contain an architecture suffix. + + Returns: + str: The tag with any supported architecture suffix removed. + """ + for arch in SUPPORTED_ARCHITECTURES: + if arch in tag: + return tag.replace(f"-{arch}", "") + return tag + + +def newer_remote_version(local_version: str, remote_version: str) -> bool: + """Check if the remote version is newer than the local version. + + Args: + local_version: The local version string in semantic version format. + remote_version: The remote version string in semantic version format. + + Returns: + bool: True if the remote version is newer than the local version, False otherwise. 
+ """ + # compare the semvers of two versions to see if the remote is "newer" + return Version(remote_version) > Version(local_version) diff --git a/dreadnode/constants.py b/dreadnode/constants.py index f2888347..47c028d9 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -5,6 +5,8 @@ # Defaults # +# name of the default local storage path +DEFAULT_LOCAL_STORAGE_DIR = pathlib.Path.home() / ".dreadnode" # name of the default server profile DEFAULT_PROFILE_NAME = "main" # default poll interval for the authentication flow @@ -20,7 +22,7 @@ # default server URL DEFAULT_SERVER_URL = f"https://platform.{DEFAULT_PLATFORM_BASE_DOMAIN}" # default local directory for dreadnode objects -DEFAULT_LOCAL_OBJECT_DIR = ".dreadnode/objects" +DEFAULT_LOCAL_OBJECT_DIR = f"{DEFAULT_LOCAL_STORAGE_DIR}/objects" # default docker registry subdomain DEFAULT_DOCKER_REGISTRY_SUBDOMAIN = "registry" # default docker registry local port @@ -54,7 +56,7 @@ # path to the user configuration file USER_CONFIG_PATH = pathlib.Path( # allow overriding the user config file via env variable - os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config" + os.getenv("DREADNODE_USER_CONFIG_FILE") or DEFAULT_LOCAL_STORAGE_DIR / "config" ) # Default values for the file system credential management diff --git a/dreadnode/main.py b/dreadnode/main.py index f2c05e4b..b64cba04 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -23,6 +23,7 @@ from dreadnode.api.client import ApiClient from dreadnode.artifact.credential_manager import CredentialManager from dreadnode.constants import ( + DEFAULT_LOCAL_STORAGE_DIR, DEFAULT_SERVER_URL, ENV_API_KEY, ENV_API_TOKEN, @@ -138,7 +139,7 @@ def __init__( self._logfire.config.ignore_no_config = True self._fs: AbstractFileSystem = LocalFileSystem(auto_mkdir=True) - self._fs_prefix: str = ".dreadnode/storage/" + self._fs_prefix: str = f"{DEFAULT_LOCAL_STORAGE_DIR}/storage/" self._initialized = False