Skip to content

Commit

Permalink
test cases for import models
Browse files Browse the repository at this point in the history
  • Loading branch information
mayoor committed May 6, 2024
1 parent 61563ca commit 48e9045
Show file tree
Hide file tree
Showing 5 changed files with 501 additions and 142 deletions.
2 changes: 2 additions & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ class FineTuningCustomMetadata(Enum):
VALIDATION_METRICS_FINAL = "validation_metrics_final"
TRINING_METRICS = "training_metrics"
VALIDATION_METRICS = "validation_metrics"

SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
1 change: 0 additions & 1 deletion ads/aqua/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ def create(
) # Give precendece to the input parameter

env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
env_var.update({"MODEL": f"{model_path_prefix}"})
if vllm_params:
params = f"{params} {vllm_params}"
env_var.update({"PARAMS": params})
Expand Down
27 changes: 18 additions & 9 deletions ads/aqua/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_artifact_path,
read_file,
upload_folder,
is_service_managed_container,
)
from ads.common.auth import default_signer
from ads.common.object_storage_details import ObjectStorageDetails
Expand Down Expand Up @@ -802,7 +803,7 @@ def _create_model_catalog_entry(
inference_container_type_smc: bool,
finetuning_container_type_smc: bool,
shadow_model: DataScienceModel,
) -> str:
) -> DataScienceModel:
"""Create model by reference from the object storage path
Args:
Expand All @@ -815,7 +816,7 @@ def _create_model_catalog_entry(
shadow_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service shadow model
Returns:
str: Display name of th model (This should be model ocid)
DataScienceModel: Returns Datascience model
"""
model_info = None
model = DataScienceModel()
Expand Down Expand Up @@ -843,11 +844,15 @@ def _create_model_catalog_entry(
metadata = ModelCustomMetadata()
if not inference_container or not finetuning_container:
containers = self._fetch_defaults_for_hf_model(model_name=model_name)
if not inference_container and not containers.inference_container:
if not inference_container and (
not containers or not containers.inference_container
):
raise ValueError(
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container"
)
if not finetuning_container and not containers.finetuning_container:
if not finetuning_container and (
not containers or not containers.finetuning_container
):
logger.warn(
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container. Proceeding with model registration without the fine-tuning container information. This model will not be available for fine tuning."
)
Expand All @@ -856,7 +861,7 @@ def _create_model_catalog_entry(
metadata.add(
key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
value=inference_container or containers.inference_container,
description="Inference container mapping for SMC",
description=f"Inference container mapping for {model_name}",
category="Other",
)
# If SMC, the container information has to be looked up from container_index.json for the latest version
Expand All @@ -877,7 +882,7 @@ def _create_model_catalog_entry(
metadata.add(
key=AQUA_FINETUNING_CONTAINER_METADATA_NAME,
value=finetuning_container or containers.finetuning_container,
description="Fine-tuning container mapping for SMC",
description=f"Fine-tuning container mapping for {model_name}",
category="Other",
)
# If SMC, the container information has to be looked up from container_index.json for the latest version
Expand Down Expand Up @@ -918,7 +923,7 @@ def _create_model_catalog_entry(
.with_freeform_tags(**tags)
).create(model_by_reference=True)
logger.debug(model)
return model.display_name
return model

def register(
self,
Expand Down Expand Up @@ -1001,8 +1006,12 @@ def register(
model_name=model_name,
inference_container=inference_container,
finetuning_container=finetuning_container,
inference_container_type_smc=inference_container_type_smc,
finetuning_container_type_smc=finetuning_container_type_smc,
inference_container_type_smc=True
if is_service_managed_container(inference_container)
else inference_container_type_smc,
finetuning_container_type_smc=True
if is_service_managed_container(inference_container)
else finetuning_container_type_smc,
shadow_model=shadow_model_details,
)

Expand Down
9 changes: 8 additions & 1 deletion ads/aqua/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import oci
from oci.data_science.models import JobRun, Model

from ads.aqua.constants import RqsAdditionalDetails
from ads.aqua.constants import (
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
RqsAdditionalDetails,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError
from ads.common.auth import default_signer, AuthState
Expand Down Expand Up @@ -786,3 +789,7 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
)

return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path


def is_service_managed_container(container):
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)

0 comments on commit 48e9045

Please sign in to comment.