Skip to content
Merged
29 changes: 29 additions & 0 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,33 @@ def get(self, id=""):
else:
raise HTTPError(400, f"The request {self.request.path} is invalid.")

@handle_exceptions
def delete(self, model_deployment_id):
    """
    Handles delete request for a model deployment.

    Parameters
    ----------
    model_deployment_id: str
        Identifier of the model deployment; passed unchanged to
        ``AquaDeploymentApp.delete``.

    Returns
    -------
    The result of ``self.finish`` invoked with the delete response.
    """
    deletion_result = AquaDeploymentApp().delete(model_deployment_id)
    return self.finish(deletion_result)

@handle_exceptions
def put(self, *args, **kwargs):
    """
    Handles put requests that activate or deactivate a model deployment.

    The request path must consist of exactly four segments:
    ``aqua/deployments/<model_deployment_id>/<action>``, where ``<action>``
    is either ``activate`` or ``deactivate``.

    Raises
    ------
    HTTPError
        Raised with status 400 when the path does not have the expected
        shape or the action is neither ``activate`` nor ``deactivate``.
    """
    segments = urlparse(self.request.path).path.strip("/").split("/")
    is_deployment_route = (
        len(segments) == 4
        and segments[0] == "aqua"
        and segments[1] == "deployments"
    )
    if not is_deployment_route:
        raise HTTPError(400, f"The request {self.request.path} is invalid.")

    model_deployment_id, action = segments[2], segments[3]
    if action == "activate":
        return self.finish(AquaDeploymentApp().activate(model_deployment_id))
    if action == "deactivate":
        return self.finish(AquaDeploymentApp().deactivate(model_deployment_id))

    # Any other trailing segment is not a supported action.
    raise HTTPError(400, f"The request {self.request.path} is invalid.")

@handle_exceptions
def post(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -264,5 +291,7 @@ def post(self, *args, **kwargs):
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
("deployments/config/?([^/]*)", AquaDeploymentHandler),
("deployments/?([^/]*)", AquaDeploymentHandler),
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
("inference", AquaDeploymentInferenceHandler),
]
1 change: 1 addition & 0 deletions ads/aqua/extension/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class Errors(str):
NO_INPUT_DATA = "No input data provided."
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'"
41 changes: 39 additions & 2 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.utils import get_hf_model_info, list_hf_models
from ads.aqua.common.utils import (
get_container_config,
get_hf_model_info,
list_hf_models,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
from ads.aqua.ui import ModelFormat
from ads.aqua.ui import AquaContainerConfig, ModelFormat


class AquaModelHandler(AquaAPIhandler):
Expand Down Expand Up @@ -73,6 +77,8 @@ def delete(self, id=""):
paths = url_parse.path.strip("/")
if paths.startswith("aqua/model/cache"):
return self.finish(AquaModelApp().clear_model_list_cache())
elif id:
return self.finish(AquaModelApp().delete_model(id))
else:
raise HTTPError(400, f"The request {self.request.path} is invalid.")

Expand Down Expand Up @@ -137,6 +143,37 @@ def post(self, *args, **kwargs):
)
)

@handle_exceptions
def put(self, id):
    """
    Handles put request for editing a registered model.

    Reads the optional ``inference_container``, ``enable_finetuning`` and
    ``task`` fields from the JSON body and forwards them to
    ``AquaModelApp.edit_registered_model``.

    Parameters
    ----------
    id: str
        Identifier of the model to edit.

    Raises
    ------
    HTTPError
        Raised with status 400 when the body cannot be parsed, is empty, or
        when ``inference_container`` is given but does not match the family
        of any known inference container.
    """
    try:
        input_data = self.get_json_body()
    except Exception as ex:
        raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

    if not input_data:
        raise HTTPError(400, Errors.NO_INPUT_DATA)

    inference_container = input_data.get("inference_container")
    if inference_container is not None:
        # The container index is only needed to validate a supplied value,
        # so skip loading it when the caller did not send one.
        inference_containers = AquaContainerConfig.from_container_index_json(
            config=get_container_config(), enable_spec=True
        ).inference
        # Compare with == rather than building a set, so that an unhashable
        # JSON value (list/dict) is still rejected with a 400.
        is_known_family = any(
            item.family == inference_container
            for item in inference_containers.values()
        )
        if not is_known_family:
            raise HTTPError(
                400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container")
            )

    enable_finetuning = input_data.get("enable_finetuning")
    task = input_data.get("task")
    return self.finish(
        AquaModelApp().edit_registered_model(
            id, inference_container, enable_finetuning, task
        )
    )


class AquaModelLicenseHandler(AquaAPIhandler):
"""Handler for Aqua Model license REST APIs."""
Expand Down
143 changes: 116 additions & 27 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
import oci
from cachetools import TTLCache
from huggingface_hub import snapshot_download
from oci.data_science.models import JobRun, Model
from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
from ads.aqua.app import AquaApp
from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags
from ads.aqua.common.enums import (
FineTuningContainerTypeFamily,
InferenceContainerTypeFamily,
Tags,
)
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.utils import (
LifecycleStatus,
Expand Down Expand Up @@ -75,7 +79,11 @@
TENANCY_OCID,
)
from ads.model import DataScienceModel
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
from ads.model.model_metadata import (
MetadataCustomCategory,
ModelCustomMetadata,
ModelCustomMetadataItem,
)
from ads.telemetry import telemetry


Expand Down Expand Up @@ -323,6 +331,97 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod

return model_details

@telemetry(entry_point="plugin=model&action=delete", name="aqua")
def delete_model(self, model_id):
    """Deletes a registered or fine-tuned model.

    Parameters
    ----------
    model_id: str
        Identifier of the model, resolved through ``DataScienceModel.from_id``.

    Returns
    -------
    The value returned by ``DataScienceModel.delete``.

    Raises
    ------
    AquaRuntimeError
        If the model carries neither the ``Tags.BASE_MODEL_CUSTOM`` nor the
        ``Tags.AQUA_FINE_TUNED_MODEL_TAG`` freeform tag.
    """
    target_model = DataScienceModel.from_id(model_id)
    registered_tag = target_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
    fine_tuned_tag = target_model.freeform_tags.get(
        Tags.AQUA_FINE_TUNED_MODEL_TAG, None
    )
    if not (registered_tag or fine_tuned_tag):
        raise AquaRuntimeError(
            f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
        )
    return target_model.delete()

@telemetry(entry_point="plugin=model&action=edit", name="aqua")
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
    """Edits the default config of unverified registered model.

    Parameters
    ----------
    id: str
        The model OCID.
    inference_container: str
        The inference container family name. The deployment container
        metadata is left untouched when this is empty or None.
    enable_finetuning: str or bool
        Flag to enable or disable finetuning over the model. Compared
        case-insensitively against "true" / "false"; any other value,
        including None, leaves the finetuning configuration unchanged.
    task: str
        The usecase type of the model, e.g. text-generation, text_embedding etc.
        The task tag is left untouched when this is empty or None.

    Returns
    -------
    Model:
        The instance of oci.data_science.models.Model.

    Raises
    ------
    AquaRuntimeError
        If the model is not a registered unverified model, or if finetuning
        is being disabled and removing the finetuning metadata or tag fails.
    """
    ds_model = DataScienceModel.from_id(id)
    freeform_tags = ds_model.freeform_tags

    # Editable models carry the custom base-model tag but not the
    # service-model tag.
    is_registered = freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
    is_service_model = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
    if not is_registered or is_service_model:
        raise AquaRuntimeError(
            f"Failed to edit model:{id}. Only registered unverified models can be edited."
        )

    custom_metadata_list = ds_model.custom_metadata_list

    if inference_container:
        custom_metadata_list.add(
            key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
            value=inference_container,
            category=MetadataCustomCategory.OTHER,
            description="Deployment container mapping for SMC",
            replace=True,
        )

    if enable_finetuning is not None:
        # Normalise through str() so a JSON boolean (True/False) is handled
        # the same way as the strings "true"/"false".
        finetuning_flag = str(enable_finetuning).lower()
        if finetuning_flag == "true":
            custom_metadata_list.add(
                key=ModelCustomMetadataFields.FINETUNE_CONTAINER,
                value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY,
                category=MetadataCustomCategory.OTHER,
                description="Fine-tuning container mapping for SMC",
                replace=True,
            )
            freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"})
        elif finetuning_flag == "false":
            try:
                custom_metadata_list.remove(
                    ModelCustomMetadataFields.FINETUNE_CONTAINER
                )
                freeform_tags.pop(Tags.READY_TO_FINE_TUNE)
            except Exception as ex:
                raise AquaRuntimeError(
                    f"The given model already doesn't support finetuning: {ex}"
                ) from ex

    # The modelDescription entry holds a boolean rather than a string in the
    # custom metadata, so it is dropped from the update payload instead of
    # being converted.
    # NOTE(review): per the PR discussion the key is immutable for
    # model-by-reference, so omitting it here is expected not to delete it
    # on the service side - confirm.
    custom_metadata_list.remove("modelDescription")

    if task:
        freeform_tags.update({Tags.TASK: task})

    updated_custom_metadata_list = [
        Metadata(**metadata)
        for metadata in custom_metadata_list.to_dict()["data"]
    ]
    update_model_details = UpdateModelDetails(
        custom_metadata_list=updated_custom_metadata_list,
        freeform_tags=freeform_tags,
    )
    return AquaApp().update_model(id, update_model_details).data

def _fetch_metric_from_metadata(
self,
custom_metadata_list: ModelCustomMetadata,
Expand Down Expand Up @@ -935,38 +1034,39 @@ def _validate_model(
# gguf extension exist.
if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)):
if (
import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
import_model_details.inference_container.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
):
self._validate_gguf_format(
import_model_details=import_model_details,
verified_model=verified_model,
gguf_model_files=gguf_model_files,
validation_result=validation_result,
model_name=model_name
model_name=model_name,
)
else:
self._validate_safetensor_format(
import_model_details=import_model_details,
verified_model=verified_model,
validation_result=validation_result,
hf_download_config_present=hf_download_config_present,
model_name=model_name
model_name=model_name,
)
elif ModelFormat.SAFETENSORS in model_formats:
self._validate_safetensor_format(
import_model_details=import_model_details,
verified_model=verified_model,
validation_result=validation_result,
hf_download_config_present=hf_download_config_present,
model_name=model_name
model_name=model_name,
)
elif ModelFormat.GGUF in model_formats:
self._validate_gguf_format(
import_model_details=import_model_details,
verified_model=verified_model,
gguf_model_files=gguf_model_files,
validation_result=validation_result,
model_name=model_name
model_name=model_name,
)

return validation_result
Expand All @@ -977,7 +1077,7 @@ def _validate_safetensor_format(
verified_model: DataScienceModel = None,
validation_result: ModelValidationResult = None,
hf_download_config_present: bool = None,
model_name: str = None
model_name: str = None,
):
if import_model_details.download_from_hf:
# validates config.json exists for safetensors model from hugginface
Expand All @@ -1004,20 +1104,13 @@ def _validate_safetensor_format(
) from ex
else:
try:
metadata_model_type = (
verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
).value
)
metadata_model_type = verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
).value
if metadata_model_type:
if (
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
in model_config
):
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
if (
model_config[
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
]
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
!= metadata_model_type
):
raise AquaRuntimeError(
Expand All @@ -1035,9 +1128,7 @@ def _validate_safetensor_format(
except Exception:
pass
if verified_model:
validation_result.telemetry_model_name = (
verified_model.display_name
)
validation_result.telemetry_model_name = verified_model.display_name
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
Expand All @@ -1049,9 +1140,7 @@ def _validate_safetensor_format(
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
else:
validation_result.telemetry_model_name = (
AQUA_MODEL_TYPE_CUSTOM
)
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM

@staticmethod
def _validate_gguf_format(
Expand Down
12 changes: 12 additions & 0 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,18 @@ def list(self, **kwargs) -> List["AquaDeployment"]:

return results

@telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
def delete(self, model_deployment_id: str):
    """Deletes the given model deployment.

    Parameters
    ----------
    model_deployment_id: str
        Identifier of the model deployment to delete.

    Returns
    -------
    The ``data`` attribute of the response returned by
    ``ds_client.delete_model_deployment``.
    """
    # NOTE(review): the PR discussion states this call is asynchronous and
    # does not wait for the deployment to be deleted - confirm.
    response = self.ds_client.delete_model_deployment(
        model_deployment_id=model_deployment_id
    )
    return response.data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asking for my own understanding: is this an asynchronous process, or will it wait until the MD is deleted? The same question applies to activating and deactivating a model deployment.
For evaluation delete/cancel, we had to add an additional wrapper to handle the wait time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be asynchronous only, just like delete/cancel for evaluations.


@telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua")
def deactivate(self, model_deployment_id: str):
    """Deactivates the given model deployment.

    Parameters
    ----------
    model_deployment_id: str
        Identifier of the model deployment to deactivate.

    Returns
    -------
    The ``data`` attribute of the response returned by
    ``ds_client.deactivate_model_deployment``.
    """
    response = self.ds_client.deactivate_model_deployment(
        model_deployment_id=model_deployment_id
    )
    return response.data

@telemetry(entry_point="plugin=deployment&action=activate", name="aqua")
def activate(self, model_deployment_id: str):
    """Activates the given model deployment.

    Parameters
    ----------
    model_deployment_id: str
        Identifier of the model deployment to activate.

    Returns
    -------
    The ``data`` attribute of the response returned by
    ``ds_client.activate_model_deployment``.
    """
    response = self.ds_client.activate_model_deployment(
        model_deployment_id=model_deployment_id
    )
    return response.data

@telemetry(entry_point="plugin=deployment&action=get", name="aqua")
def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
"""Gets the information of Aqua model deployment.
Expand Down
24 changes: 24 additions & 0 deletions tests/unitary/with_extras/aqua/test_deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,30 @@ def test_get_deployment(self, mock_get):
self.deployment_handler.get(id="mock-model-id")
mock_get.assert_called()

@patch("ads.aqua.modeldeployment.AquaDeploymentApp.delete")
def test_delete_deployment(self, mock_delete):
    """Test delete method to forward the call to AquaDeploymentApp.delete."""
    handler = self.deployment_handler
    handler.request.path = "aqua/deployments"
    handler.delete("mock-model-id")
    mock_delete.assert_called()

@patch("ads.aqua.modeldeployment.AquaDeploymentApp.activate")
def test_activate_deployment(self, mock_activate):
    """Test put method to dispatch an ``/activate`` path to AquaDeploymentApp.activate."""
    activate_path = (
        "aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/activate"
    )
    handler = self.deployment_handler
    handler.request.path = activate_path
    mock_activate.return_value = {"lifecycle_state": "UPDATING"}
    handler.put()
    mock_activate.assert_called()

@patch("ads.aqua.modeldeployment.AquaDeploymentApp.deactivate")
def test_deactivate_deployment(self, mock_deactivate):
    """Test put method to dispatch a ``/deactivate`` path to AquaDeploymentApp.deactivate."""
    deactivate_path = (
        "aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/deactivate"
    )
    handler = self.deployment_handler
    handler.request.path = deactivate_path
    mock_deactivate.return_value = {"lifecycle_state": "UPDATING"}
    handler.put()
    mock_deactivate.assert_called()

@patch("ads.aqua.modeldeployment.AquaDeploymentApp.list")
def test_list_deployment(self, mock_list):
"""Test get method to return a list of model deployments."""
Expand Down
Loading