diff --git a/docs/docs/reference/dstack.yml/service.md b/docs/docs/reference/dstack.yml/service.md index 24e1c62db..58e272f02 100644 --- a/docs/docs/reference/dstack.yml/service.md +++ b/docs/docs/reference/dstack.yml/service.md @@ -14,7 +14,7 @@ The `service` configuration type allows running [services](../../concepts/servic === "OpenAI" - #SCHEMA# dstack._internal.core.models.gateways.OpenAIChatModel + #SCHEMA# dstack.api.OpenAIChatModel overrides: show_root_heading: false type: @@ -25,7 +25,7 @@ The `service` configuration type allows running [services](../../concepts/servic > TGI provides an OpenAI-compatible API starting with version 1.4.0, so models served by TGI can be defined with `format: openai` too. - #SCHEMA# dstack._internal.core.models.gateways.TGIChatModel + #SCHEMA# dstack.api.TGIChatModel overrides: show_root_heading: false type: diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index bffe3b174..41241ed9d 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -9,11 +9,12 @@ from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import FleetConfiguration -from dstack._internal.core.models.gateways import AnyModel, GatewayConfiguration, OpenAIChatModel +from dstack._internal.core.models.gateways import GatewayConfiguration from dstack._internal.core.models.profiles import ProfileParams from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, 
parse_mount_point diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index db39ba3c1..21a32102d 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,8 +8,6 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -# TODO(#1595): refactor into different modules: gateway-specific and proxy-specific - class GatewayStatus(str, Enum): SUBMITTED = "submitted" @@ -110,69 +108,3 @@ class GatewayProvisioningData(CoreModel): availability_zone: Optional[str] = None hostname: Optional[str] = None backend_data: Optional[str] = None # backend-specific data in json - - -class BaseChatModel(CoreModel): - type: Annotated[Literal["chat"], Field(description="The type of the model")] = "chat" - name: Annotated[str, Field(description="The name of the model")] - format: Annotated[ - str, Field(description="The serving format. Supported values include `openai` and `tgi`") - ] - - -class TGIChatModel(BaseChatModel): - """ - Mapping of the model for the OpenAI-compatible endpoint. - - Attributes: - type (str): The type of the model, e.g. "chat" - name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint. - format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference. - chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used. - eos_token (Optional[str]): The custom end of sentence token. If not specified, the default end of sentence token from the HuggingFace Hub configuration will be used. - """ - - format: Annotated[ - Literal["tgi"], Field(description="The serving format. 
Must be set to `tgi`") - ] - chat_template: Annotated[ - Optional[str], - Field( - description=( - "The custom prompt template for the model." - " If not specified, the default prompt template" - " from the HuggingFace Hub configuration will be used" - ) - ), - ] = None # will be set before registering the service - eos_token: Annotated[ - Optional[str], - Field( - description=( - "The custom end of sentence token." - " If not specified, the default end of sentence token" - " from the HuggingFace Hub configuration will be used" - ) - ), - ] = None - - -class OpenAIChatModel(BaseChatModel): - """ - Mapping of the model for the OpenAI-compatible endpoint. - - Attributes: - type (str): The type of the model, e.g. "chat" - name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint. - format (str): The format of the model, i.e. "openai". - prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`. Defaults to `/v1`. - """ - - format: Annotated[ - Literal["openai"], Field(description="The serving format. Must be set to `openai`") - ] - prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1" - - -ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")] -AnyModel = Union[ChatModel] # embeddings and etc. diff --git a/src/dstack/_internal/core/models/services.py b/src/dstack/_internal/core/models/services.py new file mode 100644 index 000000000..8aa639577 --- /dev/null +++ b/src/dstack/_internal/core/models/services.py @@ -0,0 +1,76 @@ +""" +Data structures related to `type: service` runs. 
+""" + +from typing import Optional, Union + +from pydantic import Field +from typing_extensions import Annotated, Literal + +from dstack._internal.core.models.common import CoreModel + + +class BaseChatModel(CoreModel): + type: Annotated[Literal["chat"], Field(description="The type of the model")] = "chat" + name: Annotated[str, Field(description="The name of the model")] + format: Annotated[ + str, Field(description="The serving format. Supported values include `openai` and `tgi`") + ] + + +class TGIChatModel(BaseChatModel): + """ + Mapping of the model for the OpenAI-compatible endpoint. + + Attributes: + type (str): The type of the model, e.g. "chat" + name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint. + format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference. + chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used. + eos_token (Optional[str]): The custom end of sentence token. If not specified, the default end of sentence token from the HuggingFace Hub configuration will be used. + """ + + format: Annotated[ + Literal["tgi"], Field(description="The serving format. Must be set to `tgi`") + ] + chat_template: Annotated[ + Optional[str], + Field( + description=( + "The custom prompt template for the model." + " If not specified, the default prompt template" + " from the HuggingFace Hub configuration will be used" + ) + ), + ] = None # will be set before registering the service + eos_token: Annotated[ + Optional[str], + Field( + description=( + "The custom end of sentence token." 
+ " If not specified, the default end of sentence token" + " from the HuggingFace Hub configuration will be used" + ) + ), + ] = None + + +class OpenAIChatModel(BaseChatModel): + """ + Mapping of the model for the OpenAI-compatible endpoint. + + Attributes: + type (str): The type of the model, e.g. "chat" + name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint. + format (str): The format of the model, i.e. "openai". + prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`. Defaults to `/v1`. + """ + + format: Annotated[ + Literal["openai"], Field(description="The serving format. Must be set to `openai`") + ] + prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1" + + +ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")] +AnyModel = Union[ChatModel] # embeddings and etc. diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 9172c0116..3828ad609 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -import dstack._internal.server.services.gateways as gateways from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.backends.base import BackendType @@ -32,6 +31,7 @@ ) from dstack._internal.server.schemas.runner import TaskStatus from dstack._internal.server.services import logs as logs_services +from dstack._internal.server.services import services from dstack._internal.server.services.jobs import ( find_job, get_job_runtime_data, @@ -313,7 +313,7 
@@ async def _process_running_job(session: AsyncSession, job_model: JobModel): and run.run_spec.configuration.type == "service" ): try: - await gateways.register_replica(session, run_model.gateway_id, run, job_model) + await services.register_replica(session, run_model.gateway_id, run, job_model) except GatewayError as e: logger.warning( "%s: failed to register service replica: %s, age=%s", diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 54abdbdc1..33eb9f950 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import joinedload, selectinload import dstack._internal.server.services.gateways as gateways -import dstack._internal.server.services.gateways.autoscalers as autoscalers +import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError from dstack._internal.core.models.profiles import RetryEvent from dstack._internal.core.models.runs import ( diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index de385162c..151f496eb 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -3,14 +3,12 @@ import uuid from datetime import timedelta, timezone from typing import List, Optional, Sequence -from urllib.parse import urlparse import httpx from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.orm import selectinload -import dstack._internal.server.services.jobs as jobs_services import dstack._internal.utils.random_names as random_names from dstack._internal.core.backends import ( BACKENDS_WITH_GATEWAY_SUPPORT, @@ -28,8 +26,6 
@@ SSHError, ) from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import is_core_model_instance -from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration from dstack._internal.core.models.gateways import ( Gateway, GatewayComputeConfiguration, @@ -37,42 +33,26 @@ GatewayStatus, LetsEncryptGatewayCertificate, ) -from dstack._internal.core.models.runs import ( - Run, - RunSpec, - ServiceModelSpec, - ServiceSpec, -) from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.server import settings from dstack._internal.server.db import get_db -from dstack._internal.server.models import ( - GatewayComputeModel, - GatewayModel, - JobModel, - ProjectModel, - RunModel, -) +from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel from dstack._internal.server.services.backends import ( get_project_backend_by_type_or_error, get_project_backend_with_model_by_type_or_error, ) from dstack._internal.server.services.gateways.connection import GatewayConnection -from dstack._internal.server.services.gateways.options import get_service_options from dstack._internal.server.services.gateways.pool import gateway_connections_pool from dstack._internal.server.services.locking import ( advisory_lock_ctx, get_locker, string_to_lock_id, ) -from dstack._internal.server.services.logging import fmt from dstack._internal.server.utils.common import gather_map_async from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger -# TODO(#1595): refactor into different modules: gateway-specific and proxy-specific - logger = get_logger(__name__) @@ -352,201 +332,6 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> return name -async def register_service(session: 
AsyncSession, run_model: RunModel, run_spec: RunSpec): - assert is_core_model_instance(run_spec.configuration, ServiceConfiguration) - - if isinstance(run_spec.configuration.gateway, str): - gateway = await get_project_gateway_model_by_name( - session=session, project=run_model.project, name=run_spec.configuration.gateway - ) - if gateway is None: - raise ResourceNotExistsError( - f"Gateway {run_spec.configuration.gateway} does not exist" - ) - elif run_spec.configuration.gateway == False: - gateway = None - else: - gateway = run_model.project.default_gateway - - if gateway is not None: - service_spec = await _register_service_in_gateway(session, run_model, run_spec, gateway) - run_model.gateway = gateway - elif not settings.FORBID_SERVICES_WITHOUT_GATEWAY: - service_spec = _register_service_in_server(run_model, run_spec) - else: - raise ResourceNotExistsError( - "This dstack-server installation forbids services without a gateway." - " Please configure a gateway." - ) - run_model.service_spec = service_spec.json() - - -async def _register_service_in_gateway( - session: AsyncSession, run_model: RunModel, run_spec: RunSpec, gateway: GatewayModel -) -> ServiceSpec: - if gateway.gateway_compute is None: - raise ServerClientError("Gateway has no instance associated with it") - - if gateway.status != GatewayStatus.RUNNING: - raise ServerClientError("Gateway status is not running") - - gateway_configuration = get_gateway_configuration(gateway) - service_https = _get_service_https(run_spec, gateway_configuration) - service_protocol = "https" if service_https else "http" - - if service_https and gateway_configuration.certificate is None: - raise ServerClientError( - "Cannot run HTTPS service on gateway with no SSL certificates configured" - ) - - gateway_https = _get_gateway_https(gateway_configuration) - gateway_protocol = "https" if gateway_https else "http" - - wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None - if 
wildcard_domain is None: - raise ServerClientError("Domain is required for gateway") - service_spec = get_service_spec( - configuration=run_spec.configuration, - service_url=f"{service_protocol}://{run_model.run_name}.{wildcard_domain}", - model_url=f"{gateway_protocol}://gateway.{wildcard_domain}", - ) - - conn = await get_or_add_gateway_connection(session, gateway.id) - try: - logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) - async with conn.client() as client: - await client.register_service( - project=run_model.project.name, - run_name=run_model.run_name, - domain=urlparse(service_spec.url).hostname, - service_https=service_https, - gateway_https=gateway_https, - auth=run_spec.configuration.auth, - client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE, - options=service_spec.options, - ssh_private_key=run_model.project.ssh_private_key, - ) - logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) - except SSHError: - raise ServerClientError("Gateway tunnel is not working") - except httpx.RequestError as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(f"Gateway is not working: {e!r}") - - return service_spec - - -def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec: - if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: - # Note: if the user sets `https: `, it will be ignored silently - # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted - raise ServerClientError( - "The `https` configuration property is not applicable when running services without a gateway. " - "Please configure a gateway or remove the `https` property from the service configuration" - ) - if run_spec.configuration.replicas.min != run_spec.configuration.replicas.max: - raise ServerClientError( - "Auto-scaling is not yet supported when running services without a gateway. 
" - "Please configure a gateway or set `replicas` to a fixed value in the service configuration" - ) - return get_service_spec( - configuration=run_spec.configuration, - service_url=f"/proxy/services/{run_model.project.name}/{run_model.run_name}/", - model_url=f"/proxy/models/{run_model.project.name}/", - ) - - -def get_service_spec( - configuration: ServiceConfiguration, service_url: str, model_url: str -) -> ServiceSpec: - service_spec = ServiceSpec(url=service_url) - if configuration.model is not None: - service_spec.model = ServiceModelSpec( - name=configuration.model.name, - base_url=model_url, - type=configuration.model.type, - ) - service_spec.options = get_service_options(configuration) - return service_spec - - -async def register_replica( - session: AsyncSession, gateway_id: Optional[uuid.UUID], run: Run, job_model: JobModel -): - if gateway_id is None: # in-server proxy - return - conn = await get_or_add_gateway_connection(session, gateway_id) - job_submission = jobs_services.job_model_to_job_submission(job_model) - try: - logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex) - async with conn.client() as client: - await client.register_replica( - run=run, - job_submission=job_submission, - ) - logger.info("%s: replica is registered for service %s", fmt(job_model), run.id.hex) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) - - -async def unregister_service(session: AsyncSession, run_model: RunModel): - if run_model.gateway_id is None: # in-server proxy - return - conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - res = await session.execute( - select(ProjectModel).where(ProjectModel.id == run_model.project_id) - ) - project = res.scalar_one() - try: - logger.debug("%s: unregistering service", fmt(run_model)) - async with conn.client() as client: - await client.unregister_service( - project=project.name, - 
run_name=run_model.run_name, - ) - logger.debug("%s: service is unregistered", fmt(run_model)) - except GatewayError as e: - # ignore if service is not registered - logger.warning("%s: unregistering service: %s", fmt(run_model), e) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) - - -async def unregister_replica(session: AsyncSession, job_model: JobModel): - res = await session.execute( - select(RunModel) - .where(RunModel.id == job_model.run_id) - .options(joinedload(RunModel.project).joinedload(ProjectModel.backends)) - ) - run_model = res.unique().scalar_one() - if run_model.gateway_id is None: - # not a service or served by in-server proxy - return - - conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - try: - logger.debug( - "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex - ) - async with conn.client() as client: - await client.unregister_replica( - project=run_model.project.name, - run_name=run_model.run_name, - job_id=job_model.id, - ) - logger.info( - "%s: replica is unregistered from service %s", fmt(job_model), job_model.run_id.hex - ) - except GatewayError as e: - # ignore if replica is not registered - logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) - - async def get_or_add_gateway_connection( session: AsyncSession, gateway_id: uuid.UUID ) -> GatewayConnection: @@ -772,19 +557,3 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): ) if configuration.certificate.type == "acm" and configuration.backend != BackendType.AWS: raise ServerClientError("acm certificate type is supported for aws backend only") - - -def _get_service_https(run_spec: RunSpec, configuration: GatewayConfiguration) -> bool: - if not run_spec.configuration.https: 
- return False - if configuration.certificate is not None and configuration.certificate.type == "acm": - return False - return True - - -def _get_gateway_https(configuration: GatewayConfiguration) -> bool: - if configuration.certificate is not None and configuration.certificate.type == "acm": - return False - if configuration.certificate is not None and configuration.certificate.type == "lets-encrypt": - return True - return False diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 5101ad565..9892a4837 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -import dstack._internal.server.services.gateways as gateways from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import BackendError, ResourceNotExistsError, SSHError from dstack._internal.core.models.backends.base import BackendType @@ -31,6 +30,7 @@ RunModel, VolumeModel, ) +from dstack._internal.server.services import services from dstack._internal.server.services.backends import get_project_backend_by_type from dstack._internal.server.services.jobs.configurators.base import JobConfigurator from dstack._internal.server.services.jobs.configurators.dev import DevEnvironmentJobConfigurator @@ -264,7 +264,7 @@ async def process_terminating_job(session: AsyncSession, job_model: JobModel): instance.name, instance.status.name, ) - await gateways.unregister_replica( + await services.unregister_replica( session, job_model ) # TODO(egor-s) ensure always runs diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 4d6864832..8b7ef4577 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ 
b/src/dstack/_internal/server/services/proxy/repo.py @@ -9,7 +9,6 @@ from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.common import is_core_model_instance from dstack._internal.core.models.configurations import ServiceConfiguration -from dstack._internal.core.models.gateways import AnyModel from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import ( JobProvisioningData, @@ -18,6 +17,7 @@ RunStatus, ServiceSpec, ) +from dstack._internal.core.models.services import AnyModel from dstack._internal.proxy.lib.models import ( AnyModelFormat, ChatModel, diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 1b585549e..bd1bbb07b 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload -import dstack._internal.server.services.gateways as gateways import dstack._internal.utils.common as common_utils from dstack._internal.core.errors import ( RepoDoesNotExistError, @@ -62,6 +61,7 @@ VolumeModel, ) from dstack._internal.server.services import repos as repos_services +from dstack._internal.server.services import services from dstack._internal.server.services import volumes as volumes_services from dstack._internal.server.services.docker import is_valid_docker_volume_target from dstack._internal.server.services.jobs import ( @@ -473,7 +473,7 @@ async def submit_run( replicas = 1 if run_spec.configuration.type == "service": replicas = run_spec.configuration.replicas.min - await gateways.register_service(session, run_model, run_spec) + await services.register_service(session, run_model, run_spec) for replica_num in range(replicas): jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num) @@ -1031,7 +1031,7 @@ async def 
process_terminating_run(session: AsyncSession, run: RunModel): if unfinished_jobs_count == 0: if run.service_spec is not None: try: - await gateways.unregister_service(session, run) + await services.unregister_service(session, run) except Exception as e: logger.warning("%s: failed to unregister service: %s", fmt(run), repr(e)) run.status = run.termination_reason.to_status() diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py new file mode 100644 index 000000000..8235d0ea1 --- /dev/null +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -0,0 +1,247 @@ +""" +Application logic related to `type: service` runs. +""" + +import uuid +from typing import Optional +from urllib.parse import urlparse + +import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +import dstack._internal.server.services.jobs as jobs_services +from dstack._internal.core.errors import ( + GatewayError, + ResourceNotExistsError, + ServerClientError, + SSHError, +) +from dstack._internal.core.models.common import is_core_model_instance +from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration +from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus +from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec +from dstack._internal.server import settings +from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel +from dstack._internal.server.services.gateways import ( + get_gateway_configuration, + get_or_add_gateway_connection, + get_project_gateway_model_by_name, +) +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.services.options import get_service_options +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + 
+async def register_service(session: AsyncSession, run_model: RunModel, run_spec: RunSpec): + assert is_core_model_instance(run_spec.configuration, ServiceConfiguration) + + if isinstance(run_spec.configuration.gateway, str): + gateway = await get_project_gateway_model_by_name( + session=session, project=run_model.project, name=run_spec.configuration.gateway + ) + if gateway is None: + raise ResourceNotExistsError( + f"Gateway {run_spec.configuration.gateway} does not exist" + ) + elif run_spec.configuration.gateway == False: + gateway = None + else: + gateway = run_model.project.default_gateway + + if gateway is not None: + service_spec = await _register_service_in_gateway(session, run_model, run_spec, gateway) + run_model.gateway = gateway + elif not settings.FORBID_SERVICES_WITHOUT_GATEWAY: + service_spec = _register_service_in_server(run_model, run_spec) + else: + raise ResourceNotExistsError( + "This dstack-server installation forbids services without a gateway." + " Please configure a gateway." 
+ ) + run_model.service_spec = service_spec.json() + + +async def _register_service_in_gateway( + session: AsyncSession, run_model: RunModel, run_spec: RunSpec, gateway: GatewayModel +) -> ServiceSpec: + if gateway.gateway_compute is None: + raise ServerClientError("Gateway has no instance associated with it") + + if gateway.status != GatewayStatus.RUNNING: + raise ServerClientError("Gateway status is not running") + + gateway_configuration = get_gateway_configuration(gateway) + service_https = _get_service_https(run_spec, gateway_configuration) + service_protocol = "https" if service_https else "http" + + if service_https and gateway_configuration.certificate is None: + raise ServerClientError( + "Cannot run HTTPS service on gateway with no SSL certificates configured" + ) + + gateway_https = _get_gateway_https(gateway_configuration) + gateway_protocol = "https" if gateway_https else "http" + + wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None + if wildcard_domain is None: + raise ServerClientError("Domain is required for gateway") + service_spec = get_service_spec( + configuration=run_spec.configuration, + service_url=f"{service_protocol}://{run_model.run_name}.{wildcard_domain}", + model_url=f"{gateway_protocol}://gateway.{wildcard_domain}", + ) + + conn = await get_or_add_gateway_connection(session, gateway.id) + try: + logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) + async with conn.client() as client: + await client.register_service( + project=run_model.project.name, + run_name=run_model.run_name, + domain=urlparse(service_spec.url).hostname, + service_https=service_https, + gateway_https=gateway_https, + auth=run_spec.configuration.auth, + client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE, + options=service_spec.options, + ssh_private_key=run_model.project.ssh_private_key, + ) + logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) + except 
def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec:
    """Build the spec of a service that runs without a gateway.

    Args:
        run_model: The run; its project name and run name are used to build
            the `/proxy/...` URLs.
        run_spec: The run spec holding the service configuration.

    Returns:
        The `ServiceSpec` produced by `get_service_spec` with `/proxy/...` URLs.

    Raises:
        ServerClientError: If `https` differs from `SERVICE_HTTPS_DEFAULT`, or
            if the minimum and maximum replica counts differ.
    """
    configuration = run_spec.configuration
    if configuration.https != SERVICE_HTTPS_DEFAULT:
        # Note: only a non-default `https` value can be detected here. If the user
        # explicitly sets `https` to the default value, it cannot be told apart
        # from an omitted property, so it is ignored silently.
        # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted
        raise ServerClientError(
            "The `https` configuration property is not applicable when running services without a gateway. "
            "Please configure a gateway or remove the `https` property from the service configuration"
        )
    if configuration.replicas.min != configuration.replicas.max:
        raise ServerClientError(
            "Auto-scaling is not yet supported when running services without a gateway. "
            "Please configure a gateway or set `replicas` to a fixed value in the service configuration"
        )
    return get_service_spec(
        configuration=configuration,
        service_url=f"/proxy/services/{run_model.project.name}/{run_model.run_name}/",
        model_url=f"/proxy/models/{run_model.project.name}/",
    )


def get_service_spec(
    configuration: ServiceConfiguration, service_url: str, model_url: str
) -> ServiceSpec:
    """Create a `ServiceSpec` for the given service configuration.

    Args:
        configuration: The service configuration.
        service_url: The URL stored in the spec as the service URL.
        model_url: The base URL stored in the model spec; only used when the
            configuration declares a model.

    Returns:
        The spec with `url`, `options` and, if a model is configured, `model` set.
    """
    service_spec = ServiceSpec(url=service_url)
    model = configuration.model
    if model is not None:
        service_spec.model = ServiceModelSpec(
            name=model.name,
            base_url=model_url,
            type=model.type,
        )
    service_spec.options = get_service_options(configuration)
    return service_spec


async def register_replica(
    session: AsyncSession, gateway_id: Optional[uuid.UUID], run: Run, job_model: JobModel
):
    """Register the job's replica on the gateway that serves the run.

    Does nothing when `gateway_id` is None (in-server proxy).

    Raises:
        GatewayError: If the gateway request fails with `httpx.RequestError`
            or `SSHError`; the original exception is attached as the cause.
    """
    if gateway_id is None:  # in-server proxy
        return
    conn = await get_or_add_gateway_connection(session, gateway_id)
    job_submission = jobs_services.job_model_to_job_submission(job_model)
    try:
        logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex)
        async with conn.client() as client:
            await client.register_replica(
                run=run,
                job_submission=job_submission,
            )
        logger.info("%s: replica is registered for service %s", fmt(job_model), run.id.hex)
    except (httpx.RequestError, SSHError) as e:
        logger.debug("Gateway request failed", exc_info=True)
        # Chain explicitly so the transport error is reported as the direct cause
        raise GatewayError(repr(e)) from e


async def unregister_service(session: AsyncSession, run_model: RunModel):
    """Unregister the run's service from its gateway.

    Does nothing when the run has no gateway (in-server proxy). A `GatewayError`
    raised by the gateway client is only logged as a warning.

    Raises:
        GatewayError: If the gateway request fails with `httpx.RequestError`
            or `SSHError`; the original exception is attached as the cause.
    """
    if run_model.gateway_id is None:  # in-server proxy
        return
    conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
    res = await session.execute(
        select(ProjectModel).where(ProjectModel.id == run_model.project_id)
    )
    project = res.scalar_one()
    try:
        logger.debug("%s: unregistering service", fmt(run_model))
        async with conn.client() as client:
            await client.unregister_service(
                project=project.name,
                run_name=run_model.run_name,
            )
        logger.debug("%s: service is unregistered", fmt(run_model))
    except GatewayError as e:
        # ignore if service is not registered
        logger.warning("%s: unregistering service: %s", fmt(run_model), e)
    except (httpx.RequestError, SSHError) as e:
        logger.debug("Gateway request failed", exc_info=True)
        # Chain explicitly so the transport error is reported as the direct cause
        raise GatewayError(repr(e)) from e


async def unregister_replica(session: AsyncSession, job_model: JobModel):
    """Unregister the job's replica from the gateway that serves its run.

    Loads the job's run together with its project. Does nothing when the run
    has no gateway. A `GatewayError` raised by the gateway client is only
    logged as a warning.

    Raises:
        GatewayError: If the gateway request fails with `httpx.RequestError`
            or `SSHError`; the original exception is attached as the cause.
    """
    res = await session.execute(
        select(RunModel)
        .where(RunModel.id == job_model.run_id)
        .options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
    )
    run_model = res.unique().scalar_one()
    if run_model.gateway_id is None:
        # not a service or served by in-server proxy
        return

    conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
    try:
        logger.debug(
            "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex
        )
        async with conn.client() as client:
            await client.unregister_replica(
                project=run_model.project.name,
                run_name=run_model.run_name,
                job_id=job_model.id,
            )
        logger.info(
            "%s: replica is unregistered from service %s", fmt(job_model), job_model.run_id.hex
        )
    except GatewayError as e:
        # ignore if replica is not registered
        logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e)
    except (httpx.RequestError, SSHError) as e:
        logger.debug("Gateway request failed", exc_info=True)
        # Chain explicitly so the transport error is reported as the direct cause
        raise GatewayError(repr(e)) from e


def _get_service_https(run_spec: RunSpec, configuration: GatewayConfiguration) -> bool:
    """Return the HTTPS flag for a service behind the given gateway.

    False when the service configuration's `https` is falsy or when the
    gateway certificate type is "acm"; True otherwise.
    """
    if not run_spec.configuration.https:
        return False
    certificate = configuration.certificate
    if certificate is not None and certificate.type == "acm":
        return False
    return True


def _get_gateway_https(configuration: GatewayConfiguration) -> bool:
    """Return the HTTPS flag for the gateway itself.

    True only when the gateway certificate type is "lets-encrypt"; False for
    "acm", for any other certificate type, and when no certificate is set.
    """
    certificate = configuration.certificate
    return certificate is not None and certificate.type == "lets-encrypt"
dstack._internal.core.models.services import AnyModel def complete_service_model(model_info: AnyModel, env: Dict[str, str]): diff --git a/src/dstack/api/__init__.py b/src/dstack/api/__init__.py index 137080a8f..d1fe55a07 100644 --- a/src/dstack/api/__init__.py +++ b/src/dstack/api/__init__.py @@ -7,7 +7,6 @@ ServiceConfiguration as _ServiceConfiguration, ) from dstack._internal.core.models.configurations import TaskConfiguration as _TaskConfiguration -from dstack._internal.core.models.gateways import OpenAIChatModel, TGIChatModel from dstack._internal.core.models.repos.local import LocalRepo from dstack._internal.core.models.repos.remote import RemoteRepo from dstack._internal.core.models.repos.virtual import VirtualRepo @@ -15,6 +14,7 @@ from dstack._internal.core.models.resources import DiskSpec as Disk from dstack._internal.core.models.resources import GPUSpec as GPU from dstack._internal.core.models.resources import ResourcesSpec as Resources +from dstack._internal.core.models.services import OpenAIChatModel, TGIChatModel from dstack._internal.core.services.ssh.ports import PortUsedError from dstack.api._public import BackendCollection, Client, RepoCollection, RunCollection from dstack.api._public.backends import Backend diff --git a/src/tests/_internal/server/services/gateways/__init__.py b/src/tests/_internal/server/services/services/__init__.py similarity index 100% rename from src/tests/_internal/server/services/gateways/__init__.py rename to src/tests/_internal/server/services/services/__init__.py diff --git a/src/tests/_internal/server/services/gateways/test_autoscalers.py b/src/tests/_internal/server/services/services/test_autoscalers.py similarity index 98% rename from src/tests/_internal/server/services/gateways/test_autoscalers.py rename to src/tests/_internal/server/services/services/test_autoscalers.py index 8a86f7349..1e65b6842 100644 --- a/src/tests/_internal/server/services/gateways/test_autoscalers.py +++ 
b/src/tests/_internal/server/services/services/test_autoscalers.py @@ -4,7 +4,7 @@ import pytest from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, Stat -from dstack._internal.server.services.gateways.autoscalers import ReplicaInfo, RPSAutoscaler +from dstack._internal.server.services.services.autoscalers import ReplicaInfo, RPSAutoscaler @pytest.fixture