Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/dstack/_internal/cli/commands/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
)
from dstack._internal.cli.utils.gateway import get_gateways_table, print_gateways_table
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import GatewayConfiguration
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


class GatewayCommand(APIBaseCommand):
Expand Down Expand Up @@ -41,7 +45,9 @@ def _register(self):
)

create_parser = subparsers.add_parser(
"create", help="Add a gateway", formatter_class=self._parser.formatter_class
"create",
help="Add a gateway. Deprecated in favor of `dstack apply` with gateway configuration.",
formatter_class=self._parser.formatter_class,
)
create_parser.set_defaults(subfunc=self._create)
create_parser.add_argument(
Expand Down Expand Up @@ -100,10 +106,16 @@ def _list(self, args: argparse.Namespace):
pass

def _create(self, args: argparse.Namespace):
logger.warning(
"`dstack gateway create` is deperecated in favor of `dstack apply` with gateway configurations."
)
with console.status("Creating gateway..."):
gateway = self.api.client.gateways.create(
self.api.project, args.name, BackendType(args.backend), args.region
configuration = GatewayConfiguration(
name=args.name,
backend=BackendType(args.backend),
region=args.region,
)
gateway = self.api.client.gateways.create(self.api.project, configuration)
if args.set_default:
self.api.client.gateways.set_default(self.api.project, gateway.name)
if args.domain:
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/utils/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_gateways_table(
for gateway in gateways:
row = {
"NAME": gateway.name,
"BACKEND": f"{gateway.backend.value} ({gateway.region})",
"BACKEND": f"{gateway.configuration.backend.value} ({gateway.configuration.region})",
"HOSTNAME": gateway.hostname,
"DOMAIN": gateway.wildcard_domain,
"DEFAULT": "✓" if gateway.default else "",
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/models/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ def _merged_profile(cls, values) -> Dict:


class Fleet(CoreModel):
# id is Optional for backward compatibility within 0.18.x
id: Optional[uuid.UUID] = None
id: uuid.UUID
name: str
project_name: str
spec: FleetSpec
Expand Down
6 changes: 3 additions & 3 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ class Gateway(CoreModel):
# The ip address of the gateway instance
ip_address: Optional[str]
instance_id: Optional[str]
wildcard_domain: Optional[str]
default: bool
# TODO: configuration fields are duplicated on top-level for backward compatibility with 0.18.x
# Remove in 0.19
# Remove after 0.19
backend: BackendType
region: str
default: bool
wildcard_domain: Optional[str]


class GatewayPlan(CoreModel):
Expand Down
4 changes: 1 addition & 3 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ class RunPlan(CoreModel):
run_spec: RunSpec
job_plans: List[JobPlan]
current_resource: Optional[Run] = None
# Optional for backward-compatibility with 0.18.x servers
# TODO: make required in 0.19
action: Optional[ApplyAction] = None
action: ApplyAction


class ApplyRunPlanInput(CoreModel):
Expand Down
4 changes: 1 addition & 3 deletions src/dstack/_internal/core/models/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ class VolumeAttachment(CoreModel):
class Volume(CoreModel):
id: uuid.UUID
name: str
# Default user to "" for client backward compatibility (old 0.18 servers).
# TODO: Remove in 0.19
user: str = ""
user: str
project_name: str
configuration: VolumeConfiguration
external: bool
Expand Down
27 changes: 2 additions & 25 deletions src/dstack/_internal/server/schemas/gateways.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
from typing import Dict, List, Optional
from typing import List

from pydantic import root_validator

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.gateways import GatewayConfiguration


class CreateGatewayRequest(CoreModel):
name: Optional[str]
backend_type: Optional[BackendType]
region: Optional[str]
configuration: Optional[GatewayConfiguration]

@root_validator
def fill_configuration(cls, values: Dict) -> Dict:
if values.get("configuration", None) is not None:
return values
backend_type = values.get("backend_type", None)
region = values.get("region", None)
if backend_type is None:
raise ValueError("backend_type must be specified")
if region is None:
raise ValueError("region must be specified")
values["configuration"] = GatewayConfiguration(
name=values.get("name", None),
backend=backend_type,
region=region,
)
return values
configuration: GatewayConfiguration


class GetGatewayRequest(CoreModel):
Expand Down
17 changes: 3 additions & 14 deletions src/dstack/api/server/_gateways.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import List, Optional
from typing import List

from pydantic import parse_obj_as

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration
from dstack._internal.server.schemas.gateways import (
CreateGatewayRequest,
Expand All @@ -24,22 +23,12 @@ def get(self, project_name: str, gateway_name: str) -> Gateway:
resp = self._request(f"/api/project/{project_name}/gateways/get", body=body.json())
return parse_obj_as(Gateway.__response__, resp.json())

# gateway_name, backend_type, region are left for backward-compatibility with 0.18.x
# TODO: Remove in 0.19
def create(
self,
project_name: str,
gateway_name: Optional[str] = None,
backend_type: Optional[BackendType] = None,
region: Optional[str] = None,
configuration: Optional[GatewayConfiguration] = None,
configuration: GatewayConfiguration,
) -> Gateway:
body = CreateGatewayRequest(
name=gateway_name,
backend_type=backend_type,
region=region,
configuration=configuration,
)
body = CreateGatewayRequest(configuration=configuration)
resp = self._request(f"/api/project/{project_name}/gateways/create", body=body.json())
return parse_obj_as(Gateway.__response__, resp.json())

Expand Down
27 changes: 24 additions & 3 deletions src/tests/_internal/server/routers/test_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,14 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn
backend = await create_backend(session, project.id, backend_type=BackendType.AWS)
response = await client.post(
f"/api/project/{project.name}/gateways/create",
json={"name": "test", "backend_type": "aws", "region": "us"},
json={
"configuration": {
"type": "gateway",
"name": "test",
"backend": "aws",
"region": "us",
},
},
headers=get_auth_headers(user.token),
)
assert response.status_code == 200
Expand Down Expand Up @@ -218,7 +225,14 @@ async def test_create_gateway_without_name(
g.return_value = "random-name"
response = await client.post(
f"/api/project/{project.name}/gateways/create",
json={"name": None, "backend_type": "aws", "region": "us"},
json={
"configuration": {
"type": "gateway",
"name": None,
"backend": "aws",
"region": "us",
},
},
headers=get_auth_headers(user.token),
)
g.assert_called_once()
Expand Down Expand Up @@ -259,7 +273,14 @@ async def test_create_gateway_missing_backend(
)
response = await client.post(
f"/api/project/{project.name}/gateways/create",
json={"name": "test", "backend_type": "aws", "region": "us"},
json={
"configuration": {
"type": "gateway",
"name": "test",
"backend": "aws",
"region": "us",
},
},
headers=get_auth_headers(user.token),
)
assert response.status_code == 400
Expand Down