From ff3db8e0ddba9595af0e6cce79711d311a1db0fe Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 6 Feb 2025 14:07:03 +0500 Subject: [PATCH] Support --name for all configurations --- src/dstack/_internal/cli/commands/apply.py | 13 ++++++++----- .../cli/services/configurators/base.py | 6 ++++-- .../cli/services/configurators/fleet.py | 17 +++++++++++++---- .../cli/services/configurators/gateway.py | 15 +++++++++++++++ .../_internal/cli/services/configurators/run.py | 11 ++++++----- .../cli/services/configurators/volume.py | 15 +++++++++++++++ src/dstack/_internal/cli/services/repos.py | 6 +++--- .../cli/services/configurators/test_profile.py | 6 +++--- 8 files changed, 67 insertions(+), 22 deletions(-) diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py index ea60b11eb..7283969b0 100644 --- a/src/dstack/_internal/cli/commands/apply.py +++ b/src/dstack/_internal/cli/commands/apply.py @@ -6,12 +6,12 @@ get_apply_configurator_class, load_apply_configuration, ) -from dstack._internal.cli.services.configurators.base import BaseApplyConfigurator from dstack._internal.cli.services.repos import ( init_default_virtual_repo, init_repo, register_init_repo_args, ) +from dstack._internal.cli.utils.common import console from dstack._internal.core.errors import CLIError from dstack._internal.core.models.configurations import ApplyConfigurationType @@ -92,10 +92,13 @@ def _command(self, args: argparse.Namespace): configurator_class = get_apply_configurator_class( ApplyConfigurationType(args.help) ) - else: - configurator_class = BaseApplyConfigurator - configurator_class.register_args(self._parser) + configurator_class.register_args(self._parser) + self._parser.print_help() + return self._parser.print_help() + console.print( + "\nType `dstack apply -h CONFIGURATION_TYPE` to see configuration-specific options.\n" + ) return super()._command(args) @@ -129,5 +132,5 @@ def _command(self, args: argparse.Namespace): repo=repo, ) except KeyboardInterrupt: - print("\nOperation interrupted by user. Exiting...") + console.print("\nOperation interrupted by user. Exiting...") exit(0) diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py index b3bb15fd3..6bc72df55 100644 --- a/src/dstack/_internal/cli/services/configurators/base.py +++ b/src/dstack/_internal/cli/services/configurators/base.py @@ -1,7 +1,7 @@ import argparse import os from abc import ABC, abstractmethod -from typing import List, Optional, cast +from typing import List, Optional, Union, cast from dstack._internal.cli.services.args import env_var from dstack._internal.core.errors import ConfigurationError @@ -14,6 +14,8 @@ from dstack._internal.core.models.repos.base import Repo from dstack.api._public import Client +ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser] + class BaseApplyConfigurator(ABC): TYPE: ApplyConfigurationType @@ -82,7 +84,7 @@ def register_args(cls, parser: argparse.ArgumentParser): class ApplyEnvVarsConfiguratorMixin: @classmethod - def register_env_args(cls, parser: argparse.ArgumentParser): + def register_env_args(cls, parser: ArgsParser): parser.add_argument( "-e", "--env", diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 7f2d96c83..7eea9fb70 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -41,10 +41,6 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator): TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET - @classmethod - def register_args(cls, parser: argparse.ArgumentParser): - cls.register_env_args(parser) - def apply_configuration( self, conf: FleetConfiguration, @@ -185,7 +181,20 @@ def delete_configuration( console.print(f"Fleet [code]{conf.name}[/] deleted") + @classmethod + def register_args(cls, parser: argparse.ArgumentParser): + configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") + configuration_group.add_argument( + "-n", + "--name", + dest="name", + help="The fleet name", + ) + cls.register_env_args(configuration_group) + def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown: List[str]): + if args.name: + conf.name = args.name self.apply_env_vars(conf.env, args) if conf.ssh_config is None and conf.env: raise ConfigurationError("`env` is currently supported for SSH fleets only") diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 45fb06cba..a2411f974 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -39,6 +39,7 @@ def apply_configuration( unknown_args: List[str], repo: Optional[Repo] = None, ): + self.apply_args(conf, configurator_args, unknown_args) spec = GatewaySpec( configuration=conf, configuration_path=configuration_path, @@ -170,6 +171,20 @@ def delete_configuration( console.print(f"Gateway [code]{conf.name}[/] deleted") + @classmethod + def register_args(cls, parser: argparse.ArgumentParser): + configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") + configuration_group.add_argument( + "-n", + "--name", + dest="name", + help="The gateway name", + ) + + def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace, unknown: List[str]): + if args.name: + conf.name = args.name + def _get_plan(api: Client, spec: GatewaySpec) -> GatewayPlan: # TODO: Implement server-side /get_plan with an offer included diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index c7eeaa6fe..911e80a6d 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -298,20 +298,21 @@ def delete_configuration( @classmethod def register_args(cls, parser: argparse.ArgumentParser): - parser.add_argument( + configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") + configuration_group.add_argument( "-n", "--name", dest="run_name", help="The name of the run. If not specified, a random name is assigned", ) - parser.add_argument( + configuration_group.add_argument( "--max-offers", help="Number of offers to show in the run plan", type=int, default=3, ) - cls.register_env_args(parser) - parser.add_argument( + cls.register_env_args(configuration_group) + configuration_group.add_argument( "--gpu", type=gpu_spec, help="Request GPU for the run. " @@ -319,7 +320,7 @@ def register_args(cls, parser: argparse.ArgumentParser): dest="gpu_spec", metavar="SPEC", ) - parser.add_argument( + configuration_group.add_argument( "--disk", type=disk_spec, help="Request the size range of disk for the run. Example [code]--disk 100GB..[/].", diff --git a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py index d8bfb2436..77e34b788 100644 --- a/src/dstack/_internal/cli/services/configurators/volume.py +++ b/src/dstack/_internal/cli/services/configurators/volume.py @@ -38,6 +38,7 @@ def apply_configuration( unknown_args: List[str], repo: Optional[Repo] = None, ): + self.apply_args(conf, configurator_args, unknown_args) spec = VolumeSpec( configuration=conf, configuration_path=configuration_path, @@ -158,6 +159,20 @@ def delete_configuration( console.print(f"Volume [code]{conf.name}[/] deleted") + @classmethod + def register_args(cls, parser: argparse.ArgumentParser): + configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") + configuration_group.add_argument( + "-n", + "--name", + dest="name", + help="The volume name", + ) + + def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace, unknown: List[str]): + if args.name: + conf.name = args.name + def _get_plan(api: Client, spec: VolumeSpec) -> VolumePlan: # TODO: Implement server-side /get_plan with an offer included diff --git a/src/dstack/_internal/cli/services/repos.py b/src/dstack/_internal/cli/services/repos.py index e6b3e0f90..e9b1d4432 100644 --- a/src/dstack/_internal/cli/services/repos.py +++ b/src/dstack/_internal/cli/services/repos.py @@ -1,7 +1,7 @@ -from argparse import ArgumentParser, _ArgumentGroup from pathlib import Path -from typing import Optional, Union +from typing import Optional +from dstack._internal.cli.services.configurators.base import ArgsParser from dstack._internal.core.errors import CLIError from dstack._internal.core.models.repos.base import Repo, RepoType from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError @@ -12,7 +12,7 @@ from dstack.api._public import Client -def register_init_repo_args(parser: Union[ArgumentParser, _ArgumentGroup]): +def register_init_repo_args(parser: ArgsParser): parser.add_argument( "-t", "--token", diff --git a/src/tests/_internal/cli/services/configurators/test_profile.py b/src/tests/_internal/cli/services/configurators/test_profile.py index d9a047bdf..df84f7247 100644 --- a/src/tests/_internal/cli/services/configurators/test_profile.py +++ b/src/tests/_internal/cli/services/configurators/test_profile.py @@ -73,6 +73,6 @@ def apply_args(profile: Profile, args: List[str]) -> Tuple[Profile, argparse.Nam parser = argparse.ArgumentParser() register_profile_args(parser) profile = profile.copy(deep=True) # to avoid modifying the original profile - args = parser.parse_args(args) - apply_profile_args(args, profile) - return profile, args + parsed_args = parser.parse_args(args) + apply_profile_args(parsed_args, profile) + return profile, parsed_args