Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/dstack/_internal/cli/commands/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
get_apply_configurator_class,
load_apply_configuration,
)
from dstack._internal.cli.services.configurators.base import BaseApplyConfigurator
from dstack._internal.cli.services.repos import (
init_default_virtual_repo,
init_repo,
register_init_repo_args,
)
from dstack._internal.cli.utils.common import console
from dstack._internal.core.errors import CLIError
from dstack._internal.core.models.configurations import ApplyConfigurationType

Expand Down Expand Up @@ -92,10 +92,13 @@ def _command(self, args: argparse.Namespace):
configurator_class = get_apply_configurator_class(
ApplyConfigurationType(args.help)
)
else:
configurator_class = BaseApplyConfigurator
configurator_class.register_args(self._parser)
configurator_class.register_args(self._parser)
self._parser.print_help()
return
self._parser.print_help()
console.print(
"\nType `dstack apply -h CONFIGURATION_TYPE` to see configuration-specific options.\n"
)
return

super()._command(args)
Expand Down Expand Up @@ -129,5 +132,5 @@ def _command(self, args: argparse.Namespace):
repo=repo,
)
except KeyboardInterrupt:
print("\nOperation interrupted by user. Exiting...")
console.print("\nOperation interrupted by user. Exiting...")
exit(0)
6 changes: 4 additions & 2 deletions src/dstack/_internal/cli/services/configurators/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
from abc import ABC, abstractmethod
from typing import List, Optional, cast
from typing import List, Optional, Union, cast

from dstack._internal.cli.services.args import env_var
from dstack._internal.core.errors import ConfigurationError
Expand All @@ -14,6 +14,8 @@
from dstack._internal.core.models.repos.base import Repo
from dstack.api._public import Client

ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser]


class BaseApplyConfigurator(ABC):
TYPE: ApplyConfigurationType
Expand Down Expand Up @@ -82,7 +84,7 @@ def register_args(cls, parser: argparse.ArgumentParser):

class ApplyEnvVarsConfiguratorMixin:
@classmethod
def register_env_args(cls, parser: argparse.ArgumentParser):
def register_env_args(cls, parser: ArgsParser):
parser.add_argument(
"-e",
"--env",
Expand Down
17 changes: 13 additions & 4 deletions src/dstack/_internal/cli/services/configurators/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@
class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET

@classmethod
def register_args(cls, parser: argparse.ArgumentParser):
cls.register_env_args(parser)

def apply_configuration(
self,
conf: FleetConfiguration,
Expand Down Expand Up @@ -185,7 +181,20 @@ def delete_configuration(

console.print(f"Fleet [code]{conf.name}[/] deleted")

@classmethod
def register_args(cls, parser: argparse.ArgumentParser):
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
configuration_group.add_argument(
"-n",
"--name",
dest="name",
help="The fleet name",
)
cls.register_env_args(configuration_group)

def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown: List[str]):
if args.name:
conf.name = args.name
self.apply_env_vars(conf.env, args)
if conf.ssh_config is None and conf.env:
raise ConfigurationError("`env` is currently supported for SSH fleets only")
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/cli/services/configurators/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def apply_configuration(
unknown_args: List[str],
repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
spec = GatewaySpec(
configuration=conf,
configuration_path=configuration_path,
Expand Down Expand Up @@ -170,6 +171,20 @@ def delete_configuration(

console.print(f"Gateway [code]{conf.name}[/] deleted")

@classmethod
def register_args(cls, parser: argparse.ArgumentParser):
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
configuration_group.add_argument(
"-n",
"--name",
dest="name",
help="The gateway name",
)

def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace, unknown: List[str]):
if args.name:
conf.name = args.name


def _get_plan(api: Client, spec: GatewaySpec) -> GatewayPlan:
# TODO: Implement server-side /get_plan with an offer included
Expand Down
11 changes: 6 additions & 5 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,28 +298,29 @@ def delete_configuration(

@classmethod
def register_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
configuration_group.add_argument(
"-n",
"--name",
dest="run_name",
help="The name of the run. If not specified, a random name is assigned",
)
parser.add_argument(
configuration_group.add_argument(
"--max-offers",
help="Number of offers to show in the run plan",
type=int,
default=3,
)
cls.register_env_args(parser)
parser.add_argument(
cls.register_env_args(configuration_group)
configuration_group.add_argument(
"--gpu",
type=gpu_spec,
help="Request GPU for the run. "
"The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)",
dest="gpu_spec",
metavar="SPEC",
)
parser.add_argument(
configuration_group.add_argument(
"--disk",
type=disk_spec,
help="Request the size range of disk for the run. Example [code]--disk 100GB..[/].",
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/cli/services/configurators/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def apply_configuration(
unknown_args: List[str],
repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
spec = VolumeSpec(
configuration=conf,
configuration_path=configuration_path,
Expand Down Expand Up @@ -158,6 +159,20 @@ def delete_configuration(

console.print(f"Volume [code]{conf.name}[/] deleted")

@classmethod
def register_args(cls, parser: argparse.ArgumentParser):
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
configuration_group.add_argument(
"-n",
"--name",
dest="name",
help="The volume name",
)

def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace, unknown: List[str]):
if args.name:
conf.name = args.name


def _get_plan(api: Client, spec: VolumeSpec) -> VolumePlan:
# TODO: Implement server-side /get_plan with an offer included
Expand Down
6 changes: 3 additions & 3 deletions src/dstack/_internal/cli/services/repos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from argparse import ArgumentParser, _ArgumentGroup
from pathlib import Path
from typing import Optional, Union
from typing import Optional

from dstack._internal.cli.services.configurators.base import ArgsParser
from dstack._internal.core.errors import CLIError
from dstack._internal.core.models.repos.base import Repo, RepoType
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
Expand All @@ -12,7 +12,7 @@
from dstack.api._public import Client


def register_init_repo_args(parser: Union[ArgumentParser, _ArgumentGroup]):
def register_init_repo_args(parser: ArgsParser):
parser.add_argument(
"-t",
"--token",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ def apply_args(profile: Profile, args: List[str]) -> Tuple[Profile, argparse.Nam
parser = argparse.ArgumentParser()
register_profile_args(parser)
profile = profile.copy(deep=True) # to avoid modifying the original profile
args = parser.parse_args(args)
apply_profile_args(args, profile)
return profile, args
parsed_args = parser.parse_args(args)
apply_profile_args(parsed_args, profile)
return profile, parsed_args