diff --git a/src/ai/backend/client/cli/service.py b/src/ai/backend/client/cli/service.py index e55a36c97f..4a0db670f3 100644 --- a/src/ai/backend/client/cli/service.py +++ b/src/ai/backend/client/cli/service.py @@ -7,7 +7,11 @@ from ai.backend.cli.main import main from ai.backend.cli.types import ExitCode -from ai.backend.client.cli.session.execute import prepare_env_arg, prepare_resource_arg +from ai.backend.client.cli.session.execute import ( + prepare_env_arg, + prepare_mount_arg, + prepare_resource_arg, +) from ai.backend.client.compat import asyncio_run from ai.backend.client.session import AsyncSession, Session from ai.backend.common.arch import DEFAULT_IMAGE_ARCH @@ -137,6 +141,24 @@ def info(ctx: CLIContext, service_name_or_id: str): multiple=True, help="Environment variable (may appear multiple times)", ) +@click.option( + "-v", + "--volume", + "-m", + "--mount", + "mount", + metavar="NAME[=PATH] or NAME[:PATH]", + type=str, + multiple=True, + help=( + "Name or ID of virtual folders to mount." + "If path is not provided, virtual folder will be mounted under /home/work. " + "When the target path is relative, it is placed under /home/work " + "with auto-created parent directories if any. " + "Absolute paths are mounted as-is, but it is prohibited to " + "override the predefined Linux system directories." + ), +) # extra options @click.option( "--bootstrap-script", @@ -244,6 +266,7 @@ def create( model_version: Optional[str], model_mount_destination: Optional[str], env: Sequence[str], + mount: Sequence[str], startup_command: Optional[str], resources: Sequence[str], resource_opts: Sequence[str], @@ -266,12 +289,16 @@ def create( """ envs = prepare_env_arg(env) + mount, mount_map, mount_options = prepare_mount_arg(mount, escape=True) parsed_resources = prepare_resource_arg(resources) parsed_resource_opts = prepare_resource_arg(resource_opts) body = { "service_name": name, "model_version": model_version, "envs": envs, + "extra_mounts": mount, + "extra_mount_map": mount_map, + "extra_mount_options": mount_options, "startup_command": startup_command, "resources": parsed_resources, "resource_opts": parsed_resource_opts, diff --git a/src/ai/backend/client/func/service.py b/src/ai/backend/client/func/service.py index 2a75d75b71..bbb2021e30 100644 --- a/src/ai/backend/client/func/service.py +++ b/src/ai/backend/client/func/service.py @@ -4,6 +4,7 @@ from faker import Faker +from ai.backend.client.exceptions import BackendClientError from ai.backend.client.output.fields import service_fields from ai.backend.client.output.types import FieldSpec, PaginatedResult from ai.backend.client.pagination import fetch_paginated_result @@ -90,6 +91,9 @@ async def create( model_id_or_name: str, initial_session_count: int, *, + extra_mounts: Sequence[str] = [], + extra_mount_map: Mapping[str, str] = {}, + extra_mount_options: Mapping[str, Mapping[str, str]] = {}, service_name: Optional[str] = None, model_version: Optional[str] = None, dependencies: Optional[Sequence[str]] = None, @@ -143,6 +147,43 @@ async def create( faker = Faker() service_name = f"bai-serve-{faker.user_name()}" + if extra_mounts: + vfolder_id_to_name: dict[UUID, str] = {} + vfolder_name_to_id: dict[str, UUID] = {} + + rqst = Request("GET", "/folders") + async with rqst.fetch() as resp: + body = await resp.json() + for folder_info in body: + vfolder_id_to_name[UUID(folder_info["id"])] = folder_info["name"] + vfolder_name_to_id[folder_info["name"]] = UUID(folder_info["id"]) + + extra_mount_body = {} + + for mount in extra_mounts: + try: + vfolder_id = UUID(mount) + if vfolder_id not in vfolder_id_to_name: + raise BackendClientError(f"VFolder (id: {vfolder_id}) not found") + except ValueError: + if mount not in vfolder_name_to_id: + raise BackendClientError(f"VFolder (name: {vfolder_id}) not found") + vfolder_id = vfolder_name_to_id[mount] + extra_mount_body[str(vfolder_id)] = { + "mount_destination": extra_mount_map.get(mount), + "type": extra_mount_options.get(mount, {}).get("type"), + } + model_config = { + "model": model_id_or_name, + "model_mount_destination": model_mount_destination, + "extra_mounts": extra_mount_body, + "environ": envs, + "scaling_group": scaling_group, + "resources": resources, + "resource_opts": resource_opts, + } + if model_version: + model_config["model_version"] = model_version rqst = Request("POST", "/services") rqst.set_json({ "name": service_name, @@ -158,15 +199,7 @@ async def create( "bootstrap_script": bootstrap_script, "owner_access_key": owner_access_key, "open_to_public": expose_to_public, - "config": { - "model": model_id_or_name, - "model_version": model_version, - "model_mount_destination": model_mount_destination, - "environ": envs, - "scaling_group": scaling_group, - "resources": resources, - "resource_opts": resource_opts, - }, + "config": model_config, }) async with rqst.fetch() as resp: body = await resp.json()