Skip to content

Commit

Permalink
update CLI to support passing extra_mounts flag
Browse files Browse the repository at this point in the history
  • Loading branch information
kyujin-cho committed May 24, 2024
1 parent dc4935e commit d02ae54
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
29 changes: 28 additions & 1 deletion src/ai/backend/client/cli/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from ai.backend.cli.main import main
from ai.backend.cli.types import ExitCode
from ai.backend.client.cli.session.execute import prepare_env_arg, prepare_resource_arg
from ai.backend.client.cli.session.execute import (
prepare_env_arg,
prepare_mount_arg,
prepare_resource_arg,
)
from ai.backend.client.compat import asyncio_run
from ai.backend.client.session import AsyncSession, Session
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
Expand Down Expand Up @@ -137,6 +141,24 @@ def info(ctx: CLIContext, service_name_or_id: str):
multiple=True,
help="Environment variable (may appear multiple times)",
)
@click.option(
"-v",
"--volume",
"-m",
"--mount",
"mount",
metavar="NAME[=PATH] or NAME[:PATH]",
type=str,
multiple=True,
help=(
"Name or ID of virtual folders to mount."
"If path is not provided, virtual folder will be mounted under /home/work. "
"When the target path is relative, it is placed under /home/work "
"with auto-created parent directories if any. "
"Absolute paths are mounted as-is, but it is prohibited to "
"override the predefined Linux system directories."
),
)
# extra options
@click.option(
"--bootstrap-script",
Expand Down Expand Up @@ -244,6 +266,7 @@ def create(
model_version: Optional[str],
model_mount_destination: Optional[str],
env: Sequence[str],
mount: Sequence[str],
startup_command: Optional[str],
resources: Sequence[str],
resource_opts: Sequence[str],
Expand All @@ -266,12 +289,16 @@ def create(
"""
envs = prepare_env_arg(env)
mount, mount_map, mount_options = prepare_mount_arg(mount, escape=True)
parsed_resources = prepare_resource_arg(resources)
parsed_resource_opts = prepare_resource_arg(resource_opts)
body = {
"service_name": name,
"model_version": model_version,
"envs": envs,
"extra_mounts": mount,
"extra_mount_map": mount_map,
"extra_mount_options": mount_options,
"startup_command": startup_command,
"resources": parsed_resources,
"resource_opts": parsed_resource_opts,
Expand Down
51 changes: 42 additions & 9 deletions src/ai/backend/client/func/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from faker import Faker

from ai.backend.client.exceptions import BackendClientError
from ai.backend.client.output.fields import service_fields
from ai.backend.client.output.types import FieldSpec, PaginatedResult
from ai.backend.client.pagination import fetch_paginated_result
Expand Down Expand Up @@ -90,6 +91,9 @@ async def create(
model_id_or_name: str,
initial_session_count: int,
*,
extra_mounts: Sequence[str] = [],
extra_mount_map: Mapping[str, str] = {},
extra_mount_options: Mapping[str, Mapping[str, str]] = {},
service_name: Optional[str] = None,
model_version: Optional[str] = None,
dependencies: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -143,6 +147,43 @@ async def create(
faker = Faker()
service_name = f"bai-serve-{faker.user_name()}"

if extra_mounts:
vfolder_id_to_name: dict[UUID, str] = {}
vfolder_name_to_id: dict[str, UUID] = {}

rqst = Request("GET", "/folders")
async with rqst.fetch() as resp:
body = await resp.json()
for folder_info in body:
vfolder_id_to_name[UUID(folder_info["id"])] = folder_info["name"]
vfolder_name_to_id[folder_info["name"]] = UUID(folder_info["id"])

extra_mount_body = {}

for mount in extra_mounts:
try:
vfolder_id = UUID(mount)
if vfolder_id not in vfolder_id_to_name:
raise BackendClientError(f"VFolder (id: {vfolder_id}) not found")
except ValueError:
if mount not in vfolder_name_to_id:
raise BackendClientError(f"VFolder (name: {vfolder_id}) not found")
vfolder_id = vfolder_name_to_id[mount]
extra_mount_body[str(vfolder_id)] = {
"mount_destination": extra_mount_map.get(mount),
"type": extra_mount_options.get(mount, {}).get("type"),
}
model_config = {
"model": model_id_or_name,
"model_mount_destination": model_mount_destination,
"extra_mounts": extra_mount_body,
"environ": envs,
"scaling_group": scaling_group,
"resources": resources,
"resource_opts": resource_opts,
}
if model_version:
model_config["model_version"] = model_version
rqst = Request("POST", "/services")
rqst.set_json({
"name": service_name,
Expand All @@ -158,15 +199,7 @@ async def create(
"bootstrap_script": bootstrap_script,
"owner_access_key": owner_access_key,
"open_to_public": expose_to_public,
"config": {
"model": model_id_or_name,
"model_version": model_version,
"model_mount_destination": model_mount_destination,
"environ": envs,
"scaling_group": scaling_group,
"resources": resources,
"resource_opts": resource_opts,
},
"config": model_config,
})
async with rqst.fetch() as resp:
body = await resp.json()
Expand Down

0 comments on commit d02ae54

Please sign in to comment.