Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test cases for import models #822

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ class FineTuningCustomMetadata(Enum):
VALIDATION_METRICS_FINAL = "validation_metrics_final"
TRINING_METRICS = "training_metrics"
VALIDATION_METRICS = "validation_metrics"

SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
1 change: 0 additions & 1 deletion ads/aqua/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ def create(
) # Give precendece to the input parameter

env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
env_var.update({"MODEL": f"{model_path_prefix}"})
if vllm_params:
params = f"{params} {vllm_params}"
env_var.update({"PARAMS": params})
Expand Down
27 changes: 18 additions & 9 deletions ads/aqua/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_artifact_path,
read_file,
upload_folder,
is_service_managed_container,
)
from ads.common.auth import default_signer
from ads.common.object_storage_details import ObjectStorageDetails
Expand Down Expand Up @@ -802,7 +803,7 @@ def _create_model_catalog_entry(
inference_container_type_smc: bool,
finetuning_container_type_smc: bool,
shadow_model: DataScienceModel,
) -> str:
) -> DataScienceModel:
"""Create model by reference from the object storage path

Args:
Expand All @@ -815,7 +816,7 @@ def _create_model_catalog_entry(
shadow_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service shadow model

Returns:
str: Display name of th model (This should be model ocid)
DataScienceModel: Returns Datascience model
"""
model_info = None
model = DataScienceModel()
Expand Down Expand Up @@ -843,11 +844,15 @@ def _create_model_catalog_entry(
metadata = ModelCustomMetadata()
if not inference_container or not finetuning_container:
containers = self._fetch_defaults_for_hf_model(model_name=model_name)
if not inference_container and not containers.inference_container:
if not inference_container and (
not containers or not containers.inference_container
):
raise ValueError(
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container"
)
if not finetuning_container and not containers.finetuning_container:
if not finetuning_container and (
not containers or not containers.finetuning_container
):
logger.warn(
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container. Proceeding with model registration without the fine-tuning container information. This model will not be available for fine tuning."
)
Expand All @@ -856,7 +861,7 @@ def _create_model_catalog_entry(
metadata.add(
key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
value=inference_container or containers.inference_container,
description="Inference container mapping for SMC",
description=f"Inference container mapping for {model_name}",
category="Other",
)
# If SMC, the container information has to be looked up from container_index.json for the latest version
Expand All @@ -877,7 +882,7 @@ def _create_model_catalog_entry(
metadata.add(
key=AQUA_FINETUNING_CONTAINER_METADATA_NAME,
value=finetuning_container or containers.finetuning_container,
description="Fine-tuning container mapping for SMC",
description=f"Fine-tuning container mapping for {model_name}",
category="Other",
)
# If SMC, the container information has to be looked up from container_index.json for the latest version
Expand Down Expand Up @@ -918,7 +923,7 @@ def _create_model_catalog_entry(
.with_freeform_tags(**tags)
).create(model_by_reference=True)
logger.debug(model)
return model.display_name
return model

def register(
self,
Expand Down Expand Up @@ -1001,8 +1006,12 @@ def register(
model_name=model_name,
inference_container=inference_container,
finetuning_container=finetuning_container,
inference_container_type_smc=inference_container_type_smc,
finetuning_container_type_smc=finetuning_container_type_smc,
inference_container_type_smc=True
if is_service_managed_container(inference_container)
else inference_container_type_smc,
finetuning_container_type_smc=True
if is_service_managed_container(inference_container)
else finetuning_container_type_smc,
shadow_model=shadow_model_details,
)

Expand Down
9 changes: 8 additions & 1 deletion ads/aqua/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import oci
from oci.data_science.models import JobRun, Model

from ads.aqua.constants import RqsAdditionalDetails
from ads.aqua.constants import (
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
RqsAdditionalDetails,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError
from ads.common.auth import default_signer, AuthState
Expand Down Expand Up @@ -786,3 +789,7 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
)

return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path


def is_service_managed_container(container):
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
Loading