Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BYOM Model support #812

Merged
merged 7 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions THIRD_PARTY_LICENSES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ htmllistparse
* Source code: https://github.com/gumblex/htmllisting-parser
* Project home: https://github.com/gumblex/htmllisting-parser

huggingface_hub
* Copyright 2023-present, the HuggingFace Inc. team.
* License: Apache-2.0 license
* Source code: https://github.com/huggingface/huggingface_hub
* Project home: https://github.com/huggingface/huggingface_hub

ibisframework

* Copyright 2015 Cloudera Inc.
Expand Down
1 change: 1 addition & 0 deletions ads/aqua/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ class Tags(Enum):
AQUA_EVALUATION = "aqua_evaluation"
AQUA_FINE_TUNING = "aqua_finetuning"
READY_TO_FINE_TUNE = "ready_to_fine_tune"
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
73 changes: 58 additions & 15 deletions ads/aqua/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from enum import Enum
import json
import logging
from dataclasses import dataclass, field, asdict
Expand All @@ -17,6 +18,7 @@
from ads.aqua.utils import (
UNKNOWN,
MODEL_BY_REFERENCE_OSS_PATH_KEY,
get_container_config,
load_config,
get_container_image,
UNKNOWN_DICT,
Expand All @@ -38,17 +40,29 @@
)
from ads.common.serializer import DataClassSerializable
from ads.config import (
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
AQUA_MODEL_DEPLOYMENT_CONFIG,
COMPARTMENT_OCID,
AQUA_CONFIG_FOLDER,
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
AQUA_SERVED_MODEL_NAME,
)
from ads.common.object_storage_details import ObjectStorageDetails
from ads.telemetry import telemetry


class ContainerSpec:
"""
Class to hold to hold keys within the container spec.
"""

CONTAINER_SPEC = "containerSpec"
CLI_PARM = "cliParam"
SERVER_PORT = "serverPort"
HEALTH_CHECK_PORT = "healthCheckPort"
ENV_VARS = "envVars"


@dataclass
class ShapeInfo:
instance_shape: str = None
Expand Down Expand Up @@ -196,8 +210,8 @@ def create(
description: str = None,
bandwidth_mbps: int = None,
web_concurrency: int = None,
server_port: int = 8080,
health_check_port: int = 8080,
server_port: int = None,
health_check_port: int = None,
env_var: Dict = None,
) -> "AquaDeployment":
"""
Expand Down Expand Up @@ -231,9 +245,9 @@ def create(
The number of worker processes/threads to handle incoming requests
with_bucket_uri(bucket_uri)
Sets the bucket uri when uploading large size model.
server_port: (int). Defaults to 8080.
server_port: (int).
The server port for docker container image.
health_check_port: (int). Defaults to 8080.
health_check_port: (int).
The health check port for docker container image.
env_var : dict, optional
Environment variable for the deployment, by default None.
Expand All @@ -243,6 +257,7 @@ def create(
An Aqua deployment instance

"""
# TODO validate if the service model has no artifact and if it requires import step before deployment.
# Create a model catalog entry in the user compartment
aqua_model = AquaModelApp().create(
model_id=model_id, compartment_id=compartment_id, project_id=project_id
Expand Down Expand Up @@ -307,14 +322,6 @@ def create(
os_path = ObjectStorageDetails.from_path(model_path_prefix)
model_path_prefix = os_path.filepath.rstrip("/")

env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
params = f"--served-model-name {AQUA_SERVED_MODEL_NAME} --seed 42 "
if vllm_params:
params += vllm_params
env_var.update({"PARAMS": params})
env_var.update({"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"})
env_var.update({"MODEL_DEPLOY_ENABLE_STREAMING": "true"})

if is_fine_tuned_model:
_, fine_tune_output_path = get_model_by_reference_paths(
aqua_model.model_file_description
Expand All @@ -332,6 +339,7 @@ def create(

logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")

is_custom_container = False
try:
container_type_key = aqua_model.custom_metadata_list.get(
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
Expand All @@ -340,15 +348,50 @@ def create(
raise AquaValueError(
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {aqua_model.id}"
)
try:
# Check if the container override flag is set. If set, then the user has chosen custom image
if aqua_model.custom_metadata_list.get(
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME
).value:
is_custom_container = True
except Exception:
pass

# fetch image name from config
container_image = get_container_image(
container_type=container_type_key,
# If the image is of type custom, then `container_type_key` is the inference image
container_image = (
get_container_image(
container_type=container_type_key,
)
if not is_custom_container
else container_type_key
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clarifying: for SMC container, container_type_key is odsc-vllm-serving, whereas for byoc container, it will be something like <region>.ocir.io/<namespace>/user-provided-container:1.0.0.0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is right

)
logging.info(
f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
)

# Fetch the startup cli command for the container
# container_index.json will have "containerSpec" section which will provide the cli params for a given container family
container_config = get_container_config()
container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
container_type_key, {}
)
params = container_spec.get(ContainerSpec.CLI_PARM, "")
server_port = server_port or container_spec.get(
ContainerSpec.SERVER_PORT
) # Give precendece to the input parameter
health_check_port = health_check_port or container_spec.get(
ContainerSpec.HEALTH_CHECK_PORT
) # Give precendece to the input parameter

env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
env_var.update({"MODEL": f"{model_path_prefix}"})
if vllm_params:
params = f"{params} {vllm_params}"
env_var.update({"PARAMS": params})
for env in container_spec.get(ContainerSpec.ENV_VARS, []):
if isinstance(env, dict):
env_var.update(env)
# Start model deployment
# configure model deployment infrastructure
# todo : any other infrastructure params needed?
Expand Down
20 changes: 18 additions & 2 deletions ads/aqua/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ads.common.serializer import DataClassSerializable
from ads.common.utils import get_console_link
from ads.config import (
AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
AQUA_JOB_SUBNET_ID,
AQUA_MODEL_FINETUNING_CONFIG,
COMPARTMENT_OCID,
Expand Down Expand Up @@ -383,6 +384,15 @@ def create(
ft_container = source.custom_metadata_list.get(
FineTuneCustomMetadata.SERVICE_MODEL_FINE_TUNE_CONTAINER.value
).value
is_custom_container = False
try:
# Check if the container override flag is set. If set, then the user has chosen custom image
if source.custom_metadata_list.get(
AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME
).value:
is_custom_container = True
except Exception:
pass

batch_size = (
ft_config.get("shape", UNKNOWN_DICT)
Expand All @@ -406,6 +416,7 @@ def create(
),
parameters=ft_parameters,
ft_container=ft_container,
is_custom_container=is_custom_container,
)
).create()
logger.debug(
Expand Down Expand Up @@ -544,10 +555,15 @@ def _build_fine_tuning_runtime(
parameters: AquaFineTuningParams,
ft_container: str = None,
finetuning_params: str = None,
is_custom_container: bool = False,
) -> Runtime:
"""Builds fine tuning runtime for Job."""
container = get_container_image(
container_type=ft_container,
container = (
get_container_image(
container_type=ft_container,
)
if not is_custom_container
else ft_container
)
runtime = (
ContainerRuntime()
Expand Down
Loading
Loading