diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py index 3e3a6c277..72df4200f 100644 --- a/ads/aqua/extension/deployment_handler.py +++ b/ads/aqua/extension/deployment_handler.py @@ -54,6 +54,33 @@ def get(self, id=""): else: raise HTTPError(400, f"The request {self.request.path} is invalid.") + @handle_exceptions + def delete(self, model_deployment_id): + return self.finish(AquaDeploymentApp().delete(model_deployment_id)) + + @handle_exceptions + def put(self, *args, **kwargs): + """ + Handles put request for the activating and deactivating OCI datascience model deployments + Raises + ------ + HTTPError + Raises HTTPError if inputs are missing or are invalid + """ + url_parse = urlparse(self.request.path) + paths = url_parse.path.strip("/").split("/") + if len(paths) != 4 or paths[0] != "aqua" or paths[1] != "deployments": + raise HTTPError(400, f"The request {self.request.path} is invalid.") + + model_deployment_id = paths[2] + action = paths[3] + if action == "activate": + return self.finish(AquaDeploymentApp().activate(model_deployment_id)) + elif action == "deactivate": + return self.finish(AquaDeploymentApp().deactivate(model_deployment_id)) + else: + raise HTTPError(400, f"The request {self.request.path} is invalid.") + @handle_exceptions def post(self, *args, **kwargs): """ @@ -264,5 +291,7 @@ def post(self, *args, **kwargs): ("deployments/?([^/]*)/params", AquaDeploymentParamsHandler), ("deployments/config/?([^/]*)", AquaDeploymentHandler), ("deployments/?([^/]*)", AquaDeploymentHandler), + ("deployments/?([^/]*)/activate", AquaDeploymentHandler), + ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler), ("inference", AquaDeploymentInferenceHandler), ] diff --git a/ads/aqua/extension/errors.py b/ads/aqua/extension/errors.py index d5e44944c..9829ff9e4 100644 --- a/ads/aqua/extension/errors.py +++ b/ads/aqua/extension/errors.py @@ -8,3 +8,4 @@ class Errors(str): NO_INPUT_DATA = "No input data provided." MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'" MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required." + INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'" diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 5fa25992f..279ea412b 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -9,12 +9,16 @@ from ads.aqua.common.decorator import handle_exceptions from ads.aqua.common.errors import AquaRuntimeError, AquaValueError -from ads.aqua.common.utils import get_hf_model_info, list_hf_models +from ads.aqua.common.utils import ( + get_container_config, + get_hf_model_info, + list_hf_models, +) from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.errors import Errors from ads.aqua.model import AquaModelApp from ads.aqua.model.entities import AquaModelSummary, HFModelSummary -from ads.aqua.ui import ModelFormat +from ads.aqua.ui import AquaContainerConfig, ModelFormat class AquaModelHandler(AquaAPIhandler): @@ -73,6 +77,8 @@ def delete(self, id=""): paths = url_parse.path.strip("/") if paths.startswith("aqua/model/cache"): return self.finish(AquaModelApp().clear_model_list_cache()) + elif id: + return self.finish(AquaModelApp().delete_model(id)) else: raise HTTPError(400, f"The request {self.request.path} is invalid.") @@ -137,6 +143,37 @@ def post(self, *args, **kwargs): ) ) + @handle_exceptions + def put(self, id): + try: + input_data = self.get_json_body() + except Exception as ex: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + inference_container = input_data.get("inference_container") + containers = list( + AquaContainerConfig.from_container_index_json( + config=get_container_config(), enable_spec=True + ).inference.values() + ) + family_values = [item.family for item in containers] + + if inference_container is not None and inference_container not in family_values: + raise HTTPError( + 400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container") + ) + + enable_finetuning = input_data.get("enable_finetuning") + task = input_data.get("task") + return self.finish( + AquaModelApp().edit_registered_model( + id, inference_container, enable_finetuning, task + ) + ) + class AquaModelLicenseHandler(AquaAPIhandler): """Handler for Aqua Model license REST APIs.""" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index b27d28c49..d07c4ecab 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -10,11 +10,15 @@ import oci from cachetools import TTLCache from huggingface_hub import snapshot_download -from oci.data_science.models import JobRun, Model +from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger from ads.aqua.app import AquaApp -from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags +from ads.aqua.common.enums import ( + FineTuningContainerTypeFamily, + InferenceContainerTypeFamily, + Tags, +) from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( LifecycleStatus, @@ -75,7 +79,11 @@ TENANCY_OCID, ) from ads.model import DataScienceModel -from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem +from ads.model.model_metadata import ( + MetadataCustomCategory, + ModelCustomMetadata, + ModelCustomMetadataItem, +) from ads.telemetry import telemetry @@ -323,6 +331,97 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod return model_details + @telemetry(entry_point="plugin=model&action=delete", name="aqua") + def delete_model(self, model_id): + ds_model = DataScienceModel.from_id(model_id) + is_registered_model = ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) + is_fine_tuned_model = ds_model.freeform_tags.get( + Tags.AQUA_FINE_TUNED_MODEL_TAG, None + ) + if is_registered_model or is_fine_tuned_model: + return ds_model.delete() + else: + raise AquaRuntimeError( + f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted." + ) + + @telemetry(entry_point="plugin=model&action=delete", name="aqua") + def edit_registered_model(self, id, inference_container, enable_finetuning, task): + """Edits the default config of unverified registered model. + + Parameters + ---------- + id: str + The model OCID. + inference_container: str. + The inference container family name + enable_finetuning: str + Flag to enable or disable finetuning over the model. Defaults to None + task: + The usecase type of the model. e.g , text-generation , text_embedding etc. + + Returns + ------- + Model: + The instance of oci.data_science.models.Model. + + """ + ds_model = DataScienceModel.from_id(id) + if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None): + if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None): + raise AquaRuntimeError( + f"Failed to edit model:{id}. Only registered unverified models can be edited." + ) + else: + custom_metadata_list = ds_model.custom_metadata_list + freeform_tags = ds_model.freeform_tags + if inference_container: + custom_metadata_list.add( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + value=inference_container, + category=MetadataCustomCategory.OTHER, + description="Deployment container mapping for SMC", + replace=True, + ) + if enable_finetuning is not None: + if enable_finetuning.lower() == "true": + custom_metadata_list.add( + key=ModelCustomMetadataFields.FINETUNE_CONTAINER, + value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY, + category=MetadataCustomCategory.OTHER, + description="Fine-tuning container mapping for SMC", + replace=True, + ) + freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"}) + elif enable_finetuning.lower() == "false": + try: + custom_metadata_list.remove( + ModelCustomMetadataFields.FINETUNE_CONTAINER + ) + freeform_tags.pop(Tags.READY_TO_FINE_TUNE) + except Exception as ex: + raise AquaRuntimeError( + f"The given model already doesn't support finetuning: {ex}" + ) + + custom_metadata_list.remove("modelDescription") + if task: + freeform_tags.update({Tags.TASK: task}) + + updated_custom_metadata_list = [ + Metadata(**metadata) + for metadata in custom_metadata_list.to_dict()["data"] + ] + update_model_details = UpdateModelDetails( + custom_metadata_list=updated_custom_metadata_list, + freeform_tags=freeform_tags, + ) + return AquaApp().update_model(id, update_model_details).data + else: + raise AquaRuntimeError( + f"Failed to edit model:{id}. Only registered unverified models can be edited." + ) + def _fetch_metric_from_metadata( self, custom_metadata_list: ModelCustomMetadata, @@ -935,14 +1034,15 @@ def _validate_model( # gguf extension exist. if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)): if ( - import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY + import_model_details.inference_container.lower() + == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY ): self._validate_gguf_format( import_model_details=import_model_details, verified_model=verified_model, gguf_model_files=gguf_model_files, validation_result=validation_result, - model_name=model_name + model_name=model_name, ) else: self._validate_safetensor_format( @@ -950,7 +1050,7 @@ def _validate_model( verified_model=verified_model, validation_result=validation_result, hf_download_config_present=hf_download_config_present, - model_name=model_name + model_name=model_name, ) elif ModelFormat.SAFETENSORS in model_formats: self._validate_safetensor_format( @@ -958,7 +1058,7 @@ def _validate_model( verified_model=verified_model, validation_result=validation_result, hf_download_config_present=hf_download_config_present, - model_name=model_name + model_name=model_name, ) elif ModelFormat.GGUF in model_formats: self._validate_gguf_format( @@ -966,7 +1066,7 @@ def _validate_model( verified_model=verified_model, gguf_model_files=gguf_model_files, validation_result=validation_result, - model_name=model_name + model_name=model_name, ) return validation_result @@ -977,7 +1077,7 @@ def _validate_safetensor_format( verified_model: DataScienceModel = None, validation_result: ModelValidationResult = None, hf_download_config_present: bool = None, - model_name: str = None + model_name: str = None, ): if import_model_details.download_from_hf: # validates config.json exists for safetensors model from hugginface @@ -1004,20 +1104,13 @@ def _validate_safetensor_format( ) from ex else: try: - metadata_model_type = ( - verified_model.custom_metadata_list.get( - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE - ).value - ) + metadata_model_type = verified_model.custom_metadata_list.get( + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE + ).value if metadata_model_type: - if ( - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE - in model_config - ): + if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config: if ( - model_config[ - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE - ] + model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE] != metadata_model_type ): raise AquaRuntimeError( @@ -1035,9 +1128,7 @@ def _validate_safetensor_format( except Exception: pass if verified_model: - validation_result.telemetry_model_name = ( - verified_model.display_name - ) + validation_result.telemetry_model_name = verified_model.display_name elif ( model_config is not None and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config @@ -1049,9 +1140,7 @@ def _validate_safetensor_format( ): validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" else: - validation_result.telemetry_model_name = ( - AQUA_MODEL_TYPE_CUSTOM - ) + validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM @staticmethod def _validate_gguf_format( diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 654e00dc8..3fdceac04 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -485,6 +485,18 @@ def list(self, **kwargs) -> List["AquaDeployment"]: return results + @telemetry(entry_point="plugin=deployment&action=delete", name="aqua") + def delete(self,model_deployment_id:str): + return self.ds_client.delete_model_deployment(model_deployment_id=model_deployment_id).data + + @telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua") + def deactivate(self,model_deployment_id:str): + return self.ds_client.deactivate_model_deployment(model_deployment_id=model_deployment_id).data + + @telemetry(entry_point="plugin=deployment&action=activate",name="aqua") + def activate(self,model_deployment_id:str): + return self.ds_client.activate_model_deployment(model_deployment_id=model_deployment_id).data + @telemetry(entry_point="plugin=deployment&action=get", name="aqua") def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": """Gets the information of Aqua model deployment. diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py index 54756545a..a3c843113 100644 --- a/tests/unitary/with_extras/aqua/test_deployment_handler.py +++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py @@ -92,6 +92,30 @@ def test_get_deployment(self, mock_get): self.deployment_handler.get(id="mock-model-id") mock_get.assert_called() + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.delete") + def test_delete_deployment(self, mock_delete): + self.deployment_handler.request.path = "aqua/deployments" + self.deployment_handler.delete("mock-model-id") + mock_delete.assert_called() + + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.activate") + def test_activate_deployment(self, mock_activate): + self.deployment_handler.request.path = ( + "aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/activate" + ) + mock_activate.return_value = {"lifecycle_state": "UPDATING"} + self.deployment_handler.put() + mock_activate.assert_called() + + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.deactivate") + def test_deactivate_deployment(self, mock_deactivate): + self.deployment_handler.request.path = ( + "aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/deactivate" + ) + mock_deactivate.return_value = {"lifecycle_state": "UPDATING"} + self.deployment_handler.put() + mock_deactivate.assert_called() + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.list") def test_list_deployment(self, mock_list): """Test get method to return a list of model deployments.""" diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index cb7a27080..db5475798 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -19,8 +19,8 @@ AquaModelLicenseHandler, ) from ads.aqua.model import AquaModelApp -from ads.aqua.model.constants import ModelTask from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary +from ads.aqua.ui import AquaContainerConfig class ModelHandlerTestCase(TestCase): @@ -79,6 +79,49 @@ def test_delete(self, mock_urlparse, mock_clear_model_list_cache): mock_urlparse.assert_called() mock_clear_model_list_cache.assert_called() + @patch("ads.aqua.extension.model_handler.urlparse") + @patch.object(AquaModelApp, "delete_model") + def test_delete_with_id(self, mock_delete, mock_urlparse): + request_path = MagicMock(path="aqua/model/ocid1.datasciencemodel.oc1.iad.xxx") + mock_urlparse.return_value = request_path + mock_delete.return_value = {"state": "DELETED"} + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + result = self.model_handler.delete(id="ocid1.datasciencemodel.oc1.iad.xxx") + assert result["state"] is "DELETED" + mock_urlparse.assert_called() + mock_delete.assert_called() + + @patch.object(AquaContainerConfig, "from_container_index_json") + @patch.object(AquaModelApp, "edit_registered_model") + def test_put(self, mock_edit, mock_container_index): + mock_edit.return_value = {"state": "EDITED"} + mock_inference = MagicMock() + mock_inference.values.return_value = [ + MagicMock(family="odsc-vllm-serving"), + MagicMock(family="odsc-tgi-serving"), + MagicMock(family="odsc-vllm-serving"), + ] + + mock_container_index.return_value = MagicMock(inference=mock_inference) + self.model_handler.get_json_body = MagicMock( + return_value=dict( + task="text_generation", + enable_finetuning="true", + inference_container="odsc-tgi-serving", + ) + ) + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + result = self.model_handler.put(id="ocid1.datasciencemodel.oc1.iad.xxx") + print(f"result: ", result) + assert result["state"] is "EDITED" + mock_edit.assert_called() + @patch.object(AquaModelApp, "list") def test_list(self, mock_list): with patch(