diff --git a/ads/aqua/deployment.py b/ads/aqua/deployment.py index ab68c0e64..247686705 100644 --- a/ads/aqua/deployment.py +++ b/ads/aqua/deployment.py @@ -3,6 +3,7 @@ # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from enum import Enum import json import logging from dataclasses import dataclass, field, asdict @@ -17,6 +18,7 @@ from ads.aqua.utils import ( UNKNOWN, MODEL_BY_REFERENCE_OSS_PATH_KEY, + get_container_config, load_config, get_container_image, UNKNOWN_DICT, @@ -44,12 +46,23 @@ AQUA_CONFIG_FOLDER, AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS, AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, - AQUA_SERVED_MODEL_NAME, ) from ads.common.object_storage_details import ObjectStorageDetails from ads.telemetry import telemetry +class ContainerSpec: + """ + Class to hold to hold keys within the container spec. + """ + + CONTAINER_SPEC = "containerSpec" + CLI_PARM = "cliParam" + SERVER_PORT = "serverPort" + HEALTH_CHECK_PORT = "healthCheckPort" + ENV_VARS = "envVars" + + @dataclass class ShapeInfo: instance_shape: str = None @@ -197,8 +210,8 @@ def create( description: str = None, bandwidth_mbps: int = None, web_concurrency: int = None, - server_port: int = 8080, - health_check_port: int = 8080, + server_port: int = None, + health_check_port: int = None, env_var: Dict = None, ) -> "AquaDeployment": """ @@ -232,9 +245,9 @@ def create( The number of worker processes/threads to handle incoming requests with_bucket_uri(bucket_uri) Sets the bucket uri when uploading large size model. - server_port: (int). Defaults to 8080. + server_port: (int). The server port for docker container image. - health_check_port: (int). Defaults to 8080. + health_check_port: (int). The health check port for docker container image. env_var : dict, optional Environment variable for the deployment, by default None. @@ -244,6 +257,7 @@ def create( An Aqua deployment instance """ + # TODO validate if the service model has no artifact and if it requires import step before deployment. # Create a model catalog entry in the user compartment aqua_model = AquaModelApp().create( model_id=model_id, compartment_id=compartment_id, project_id=project_id @@ -308,14 +322,6 @@ def create( os_path = ObjectStorageDetails.from_path(model_path_prefix) model_path_prefix = os_path.filepath.rstrip("/") - env_var.update({"BASE_MODEL": f"{model_path_prefix}"}) - params = f"--served-model-name {AQUA_SERVED_MODEL_NAME} --seed 42 " - if vllm_params: - params += vllm_params - env_var.update({"PARAMS": params}) - env_var.update({"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"}) - env_var.update({"MODEL_DEPLOY_ENABLE_STREAMING": "true"}) - if is_fine_tuned_model: _, fine_tune_output_path = get_model_by_reference_paths( aqua_model.model_file_description @@ -364,6 +370,28 @@ def create( f"Aqua Image used for deploying {aqua_model.id} : {container_image}" ) + # Fetch the startup cli command for the container + # container_index.json will have "containerSpec" section which will provide the cli params for a given container family + container_config = get_container_config() + container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get( + container_type_key, {} + ) + params = container_spec.get(ContainerSpec.CLI_PARM, "") + server_port = server_port or container_spec.get( + ContainerSpec.SERVER_PORT + ) # Give precendece to the input parameter + health_check_port = health_check_port or container_spec.get( + ContainerSpec.HEALTH_CHECK_PORT + ) # Give precendece to the input parameter + + env_var.update({"BASE_MODEL": f"{model_path_prefix}"}) + env_var.update({"MODEL": f"{model_path_prefix}"}) + if vllm_params: + params = f"{params} {vllm_params}" + env_var.update({"PARAMS": params}) + for env in container_spec.get(ContainerSpec.ENV_VARS, []): + if isinstance(env, dict): + env_var.update(env) # Start model deployment # configure model deployment infrastructure # todo : any other infrastructure params needed? diff --git a/ads/aqua/utils.py b/ads/aqua/utils.py index 7f32ba0bf..d0eca6f7c 100644 --- a/ads/aqua/utils.py +++ b/ads/aqua/utils.py @@ -522,6 +522,19 @@ def _build_job_identifier( return AquaResourceIdentifier() +def container_config_path(): + return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config" + + +def get_container_config(): + config = load_config( + file_path=container_config_path(), + config_file_name=CONTAINER_INDEX, + ) + + return config + + def get_container_image( config_file_name: str = None, container_type: str = None ) -> str: @@ -539,14 +552,8 @@ def get_container_image( A dict of allowed configs. """ - config_file_name = ( - f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config" - ) - - config = load_config( - file_path=config_file_name, - config_file_name=CONTAINER_INDEX, - ) + config = config_file_name or get_container_config() + config_file_name = container_config_path() if container_type not in config: raise AquaValueError( @@ -771,6 +778,7 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str: object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/" command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile}" try: + logger.info(f"Running: {command}") subprocess.check_call(shlex.split(command)) except subprocess.CalledProcessError as e: logger.error(