Skip to content

Commit

Permalink
Externalize container configuration for deployment
Browse files Browse the repository at this point in the history
  • Loading branch information
mayoor committed May 3, 2024
1 parent 27e12c8 commit f5a367e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 21 deletions.
53 changes: 40 additions & 13 deletions ads/aqua/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from enum import Enum
import json
import logging
from dataclasses import dataclass, field, asdict
Expand All @@ -17,6 +18,7 @@
from ads.aqua.utils import (
UNKNOWN,
MODEL_BY_REFERENCE_OSS_PATH_KEY,
get_container_config,
load_config,
get_container_image,
UNKNOWN_DICT,
Expand Down Expand Up @@ -44,12 +46,22 @@
AQUA_CONFIG_FOLDER,
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
AQUA_SERVED_MODEL_NAME,
)
from ads.common.object_storage_details import ObjectStorageDetails
from ads.telemetry import telemetry


class ContainerSpec:
"""
Class to hold to hold keys within the container spec.
"""

CLI_PARM = "cliParam"
SERVER_PORT = "serverPort"
HEALTH_CHECK_PORT = "healthCheckPort"
ENV_VARS = "envVars"


@dataclass
class ShapeInfo:
instance_shape: str = None
Expand Down Expand Up @@ -197,8 +209,8 @@ def create(
description: str = None,
bandwidth_mbps: int = None,
web_concurrency: int = None,
server_port: int = 8080,
health_check_port: int = 8080,
server_port: int = None,
health_check_port: int = None,
env_var: Dict = None,
) -> "AquaDeployment":
"""
Expand Down Expand Up @@ -232,9 +244,9 @@ def create(
The number of worker processes/threads to handle incoming requests
with_bucket_uri(bucket_uri)
Sets the bucket uri when uploading large size model.
server_port: (int). Defaults to 8080.
server_port: (int).
The server port for docker container image.
health_check_port: (int). Defaults to 8080.
health_check_port: (int).
The health check port for docker container image.
env_var : dict, optional
Environment variable for the deployment, by default None.
Expand All @@ -244,6 +256,7 @@ def create(
An Aqua deployment instance
"""
# TODO validate if the service model has no artifact and if it requires import step before deployment.
# Create a model catalog entry in the user compartment
aqua_model = AquaModelApp().create(
model_id=model_id, compartment_id=compartment_id, project_id=project_id
Expand Down Expand Up @@ -308,14 +321,6 @@ def create(
os_path = ObjectStorageDetails.from_path(model_path_prefix)
model_path_prefix = os_path.filepath.rstrip("/")

env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
params = f"--served-model-name {AQUA_SERVED_MODEL_NAME} --seed 42 "
if vllm_params:
params += vllm_params
env_var.update({"PARAMS": params})
env_var.update({"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"})
env_var.update({"MODEL_DEPLOY_ENABLE_STREAMING": "true"})

if is_fine_tuned_model:
_, fine_tune_output_path = get_model_by_reference_paths(
aqua_model.model_file_description
Expand Down Expand Up @@ -364,6 +369,28 @@ def create(
f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
)

# Fetch the startup cli command for the container
# container_index.json will have "container-spec" section which will provide the cli params for a given container family
container_config = get_container_config()
container_spec = container_config.get("container-spec", {}).get(
container_type_key, {}
)
params = container_spec.get(ContainerSpec.CLI_PARM, "")
server_port = server_port or container_spec.get(
ContainerSpec.SERVER_PORT
) # Give precendece to the input parameter
health_check_port = health_check_port or container_spec.get(
ContainerSpec.HEALTH_CHECK_PORT
) # Give precendece to the input parameter

env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
env_var.update({"MODEL": f"{model_path_prefix}"})
if vllm_params:
params = f"{params} {vllm_params}"
env_var.update({"PARAMS": params})
for env in container_spec.get(ContainerSpec.ENV_VARS, []):
if isinstance(env, dict):
env_var.update(env)
# Start model deployment
# configure model deployment infrastructure
# todo : any other infrastructure params needed?
Expand Down
24 changes: 16 additions & 8 deletions ads/aqua/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,19 @@ def _build_job_identifier(
return AquaResourceIdentifier()


def container_config_path():
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"


def get_container_config():
config = load_config(
file_path=container_config_path(),
config_file_name=CONTAINER_INDEX,
)

return config


def get_container_image(
config_file_name: str = None, container_type: str = None
) -> str:
Expand All @@ -539,14 +552,8 @@ def get_container_image(
A dict of allowed configs.
"""

config_file_name = (
f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
)

config = load_config(
file_path=config_file_name,
config_file_name=CONTAINER_INDEX,
)
config = config_file_name or get_container_config()
config_file_name = container_config_path()

if container_type not in config:
raise AquaValueError(
Expand Down Expand Up @@ -771,6 +778,7 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile}"
try:
logger.info(f"Running: {command}")
subprocess.check_call(shlex.split(command))
except subprocess.CalledProcessError as e:
logger.error(
Expand Down

0 comments on commit f5a367e

Please sign in to comment.