From 48e9045d4f6aa60b70d52863a0d27bad3e7814cd Mon Sep 17 00:00:00 2001 From: Mayoor Rao Date: Mon, 6 May 2024 14:24:19 -0700 Subject: [PATCH] test cases for import models --- ads/aqua/constants.py | 2 + ads/aqua/deployment.py | 1 - ads/aqua/model.py | 27 +- ads/aqua/utils.py | 9 +- tests/unitary/with_extras/aqua/test_model.py | 604 +++++++++++++++---- 5 files changed, 501 insertions(+), 142 deletions(-) diff --git a/ads/aqua/constants.py b/ads/aqua/constants.py index 0d4ea4f78..c04d8f3e1 100644 --- a/ads/aqua/constants.py +++ b/ads/aqua/constants.py @@ -43,3 +43,5 @@ class FineTuningCustomMetadata(Enum): VALIDATION_METRICS_FINAL = "validation_metrics_final" TRINING_METRICS = "training_metrics" VALIDATION_METRICS = "validation_metrics" + +SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://" diff --git a/ads/aqua/deployment.py b/ads/aqua/deployment.py index 247686705..2ab25ef83 100644 --- a/ads/aqua/deployment.py +++ b/ads/aqua/deployment.py @@ -385,7 +385,6 @@ def create( ) # Give precendece to the input parameter env_var.update({"BASE_MODEL": f"{model_path_prefix}"}) - env_var.update({"MODEL": f"{model_path_prefix}"}) if vllm_params: params = f"{params} {vllm_params}" env_var.update({"PARAMS": params}) diff --git a/ads/aqua/model.py b/ads/aqua/model.py index 00b7266b4..0c7647c8e 100644 --- a/ads/aqua/model.py +++ b/ads/aqua/model.py @@ -39,6 +39,7 @@ get_artifact_path, read_file, upload_folder, + is_service_managed_container, ) from ads.common.auth import default_signer from ads.common.object_storage_details import ObjectStorageDetails @@ -802,7 +803,7 @@ def _create_model_catalog_entry( inference_container_type_smc: bool, finetuning_container_type_smc: bool, shadow_model: DataScienceModel, - ) -> str: + ) -> DataScienceModel: """Create model by reference from the object storage path Args: @@ -815,7 +816,7 @@ def _create_model_catalog_entry( shadow_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service shadow model Returns: - str: Display name of th model (This should be model ocid) + DataScienceModel: Returns Datascience model """ model_info = None model = DataScienceModel() @@ -843,11 +844,15 @@ def _create_model_catalog_entry( metadata = ModelCustomMetadata() if not inference_container or not finetuning_container: containers = self._fetch_defaults_for_hf_model(model_name=model_name) - if not inference_container and not containers.inference_container: + if not inference_container and ( + not containers or not containers.inference_container + ): raise ValueError( f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container" ) - if not finetuning_container and not containers.finetuning_container: + if not finetuning_container and ( + not containers or not containers.finetuning_container + ): logger.warn( f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container. Proceeding with model registration without the fine-tuning container information. This model will not be available for fine tuning." ) @@ -856,7 +861,7 @@ def _create_model_catalog_entry( metadata.add( key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, value=inference_container or containers.inference_container, - description="Inference container mapping for SMC", + description=f"Inference container mapping for {model_name}", category="Other", ) # If SMC, the container information has to be looked up from container_index.json for the latest version @@ -877,7 +882,7 @@ def _create_model_catalog_entry( metadata.add( key=AQUA_FINETUNING_CONTAINER_METADATA_NAME, value=finetuning_container or containers.finetuning_container, - description="Fine-tuning container mapping for SMC", + description=f"Fine-tuning container mapping for {model_name}", category="Other", ) # If SMC, the container information has to be looked up from container_index.json for the latest version @@ -918,7 +923,7 @@ def _create_model_catalog_entry( .with_freeform_tags(**tags) ).create(model_by_reference=True) logger.debug(model) - return model.display_name + return model def register( self, @@ -1001,8 +1006,12 @@ def register( model_name=model_name, inference_container=inference_container, finetuning_container=finetuning_container, - inference_container_type_smc=inference_container_type_smc, - finetuning_container_type_smc=finetuning_container_type_smc, + inference_container_type_smc=True + if is_service_managed_container(inference_container) + else inference_container_type_smc, + finetuning_container_type_smc=True + if is_service_managed_container(inference_container) + else finetuning_container_type_smc, shadow_model=shadow_model_details, ) diff --git a/ads/aqua/utils.py b/ads/aqua/utils.py index d0eca6f7c..3850f6685 100644 --- a/ads/aqua/utils.py +++ b/ads/aqua/utils.py @@ -22,7 +22,10 @@ import oci from oci.data_science.models import JobRun, Model -from ads.aqua.constants import RqsAdditionalDetails +from ads.aqua.constants import ( + SERVICE_MANAGED_CONTAINER_URI_SCHEME, + RqsAdditionalDetails, +) from ads.aqua.data import AquaResourceIdentifier from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError from ads.common.auth import default_signer, AuthState @@ -786,3 +789,7 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str: ) return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path + + +def is_service_managed_container(container): + return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME) diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 42328e59a..5789be9be 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -8,17 +8,44 @@ import unittest from dataclasses import asdict from importlib import reload -from unittest.mock import MagicMock +from unittest.mock import MagicMock, PropertyMock from mock import patch import oci +import ads.common +from ads.common.object_storage_details import ObjectStorageDetails +import ads.common.oci_client +from ads.model.service.oci_datascience_model import OCIDataScienceModel from parameterized import parameterized import ads.aqua.model import ads.config from ads.aqua.model import AquaModelApp, AquaModelSummary from ads.model.datascience_model import DataScienceModel -from ads.model.model_metadata import ModelCustomMetadata, ModelProvenanceMetadata, ModelTaxonomyMetadata +from ads.model.model_metadata import ( + ModelCustomMetadata, + ModelProvenanceMetadata, + ModelTaxonomyMetadata, +) + +import shlex +import tempfile +import huggingface_hub +import pytest + + +@pytest.fixture(autouse=True, scope="class") +def mock_auth(): + with patch("ads.common.auth.default_signer") as mock_default_signer: + yield mock_default_signer + + +@pytest.fixture(autouse=True, scope="class") +def mock_init_client(): + with patch( + "ads.common.oci_datascience.OCIDataScienceMixin.init_client" + ) as mock_client: + yield mock_client class TestDataset: @@ -70,14 +97,20 @@ class TestDataset: COMPARTMENT_ID = "ocid1.compartment.oc1.." -class TestAquaModel(unittest.TestCase): +@patch("ads.config.COMPARTMENT_OCID", "ocid1.compartment.oc1.") +@patch("ads.config.PROJECT_OCID", "ocid1.datascienceproject.oc1.iad.") +class TestAquaModel: """Contains unittests for AquaModelApp.""" - def setUp(self): + def setup_method(self): + ads.common.auth.default_signer = MagicMock() + ads.common.auth.APIKey.create_signer = MagicMock() + oci.config.validate_config = MagicMock() + ads.common.oci_client.OCIClientFactory.create_client = MagicMock() self.app = AquaModelApp() @classmethod - def setUpClass(cls): + def setup_class(cls): os.environ["CONDA_BUCKET_NS"] = "test-namespace" os.environ["ODSC_MODEL_COMPARTMENT_OCID"] = TestDataset.SERVICE_COMPARTMENT_ID reload(ads.config) @@ -85,7 +118,7 @@ def setUpClass(cls): reload(ads.aqua.model) @classmethod - def tearDownClass(cls): + def teardown_class(cls): os.environ.pop("CONDA_BUCKET_NS", None) os.environ.pop("ODSC_MODEL_COMPARTMENT_OCID", None) reload(ads.config) @@ -97,20 +130,19 @@ def tearDownClass(cls): @patch.object(DataScienceModel, "from_id") def test_create_model(self, mock_from_id, mock_validate, mock_create): mock_model = MagicMock() - mock_model.model_file_description = {"test_key":"test_value"} + mock_model.model_file_description = {"test_key": "test_value"} mock_model.display_name = "test_display_name" mock_model.description = "test_description" mock_model.freeform_tags = { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task", - "ready_to_fine_tune":"true" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", + "ready_to_fine_tune": "true", } custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key="test_metadata_item_key", - value="test_metadata_item_value" + **{"key": "test_metadata_item_key", "value": "test_metadata_item_value"} ) mock_model.custom_metadata_list = custom_metadata_list mock_model.provenance_metadata = ModelProvenanceMetadata( @@ -118,7 +150,7 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create): ) mock_from_id.return_value = mock_model - # will not copy service model + # will not copy service model self.app.create( model_id="test_model_id", project_id="test_project_id", @@ -137,28 +169,27 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create): model = self.app.create( model_id="test_model_id", project_id="test_project_id", - compartment_id="test_compartment_id" + compartment_id="test_compartment_id", ) mock_from_id.assert_called_with("test_model_id") mock_validate.assert_called() - mock_create.assert_called_with( - model_by_reference=True - ) + mock_create.assert_called_with(model_by_reference=True) assert model.display_name == "test_display_name" assert model.description == "test_description" assert model.description == "test_description" assert model.freeform_tags == { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task", - "ready_to_fine_tune":"true" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", + "ready_to_fine_tune": "true", } - assert model.custom_metadata_list.get( - "test_metadata_item_key" - ).value == "test_metadata_item_value" + assert ( + model.custom_metadata_list.get("test_metadata_item_key").value + == "test_metadata_item_value" + ) assert model.provenance_metadata.training_id == "test_training_id" @patch("ads.aqua.model.read_file") @@ -171,16 +202,15 @@ def test_get_model_not_fine_tuned(self, mock_from_id, mock_read_file): ds_model.display_name = "test_display_name" ds_model.description = "test_description" ds_model.freeform_tags = { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", } ds_model.time_created = "2024-01-19T17:57:39.158000+00:00" custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key="artifact_location", - value="oci://bucket@namespace/prefix/" + **{"key": "artifact_location", "value": "oci://bucket@namespace/prefix/"} ) ds_model.custom_metadata_list = custom_metadata_list @@ -196,30 +226,32 @@ def test_get_model_not_fine_tuned(self, mock_from_id, mock_read_file): ) assert asdict(aqua_model) == { - 'compartment_id': f'{ds_model.compartment_id}', - 'console_link': ( - f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}', + "compartment_id": f"{ds_model.compartment_id}", + "console_link": ( + f"https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}", ), - 'icon': '', - 'id': f'{ds_model.id}', - 'is_fine_tuned_model': False, - 'license': f'{ds_model.freeform_tags["license"]}', - 'model_card': f'{mock_read_file.return_value}', - 'name': f'{ds_model.display_name}', - 'organization': f'{ds_model.freeform_tags["organization"]}', - 'project_id': f'{ds_model.project_id}', - 'ready_to_deploy': True, - 'ready_to_finetune': False, - 'search_text': 'ACTIVE,test_license,test_organization,test_task', - 'tags': ds_model.freeform_tags, - 'task': f'{ds_model.freeform_tags["task"]}', - 'time_created': f'{ds_model.time_created}' + "icon": "", + "id": f"{ds_model.id}", + "is_fine_tuned_model": False, + "license": f'{ds_model.freeform_tags["license"]}', + "model_card": f"{mock_read_file.return_value}", + "name": f"{ds_model.display_name}", + "organization": f'{ds_model.freeform_tags["organization"]}', + "project_id": f"{ds_model.project_id}", + "ready_to_deploy": True, + "ready_to_finetune": False, + "search_text": "ACTIVE,test_license,test_organization,test_task", + "tags": ds_model.freeform_tags, + "task": f'{ds_model.freeform_tags["task"]}', + "time_created": f"{ds_model.time_created}", } @patch("ads.aqua.utils.query_resource") @patch("ads.aqua.model.read_file") @patch.object(DataScienceModel, "from_id") - def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_resource): + def test_get_model_fine_tuned( + self, mock_from_id, mock_read_file, mock_query_resource + ): ds_model = MagicMock() ds_model.id = "test_id" ds_model.compartment_id = "test_model_compartment_id" @@ -229,32 +261,30 @@ def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_res ds_model.model_version_set_id = "test_model_version_set_id" ds_model.model_version_set_name = "test_model_version_set_name" ds_model.freeform_tags = { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task", - "aqua_fine_tuned_model":"test_finetuned_model" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", + "aqua_fine_tuned_model": "test_finetuned_model", } + self.app._service_model_details_cache.get = MagicMock(return_value=None) ds_model.time_created = "2024-01-19T17:57:39.158000+00:00" ds_model.lifecycle_state = "ACTIVE" custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key="artifact_location", - value="oci://bucket@namespace/prefix/" + **{"key": "artifact_location", "value": "oci://bucket@namespace/prefix/"} ) custom_metadata_list.add( - key="fine_tune_source", - value="test_fine_tuned_source_id" + **{"key": "fine_tune_source", "value": "test_fine_tuned_source_id"} ) custom_metadata_list.add( - key="fine_tune_source_name", - value="test_fine_tuned_source_name" + **{"key": "fine_tune_source_name", "value": "test_fine_tuned_source_name"} ) ds_model.custom_metadata_list = custom_metadata_list defined_metadata_list = ModelTaxonomyMetadata() defined_metadata_list["Hyperparameters"].value = { - "training_data" : "test_training_data", - "val_set_size" : "test_val_set_size" + "training_data": "test_training_data", + "val_set_size": "test_val_set_size", } ds_model.defined_metadata_list = defined_metadata_list ds_model.provenance_metadata = ModelProvenanceMetadata( @@ -269,26 +299,24 @@ def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_res job_run.id = "test_job_run_id" job_run.lifecycle_state = "SUCCEEDED" job_run.lifecycle_details = "test lifecycle details" - job_run.identifier = "test_job_id", + job_run.identifier = ("test_job_id",) job_run.display_name = "test_job_name" job_run.compartment_id = "test_job_run_compartment_id" job_infrastructure_configuration_details = MagicMock() job_infrastructure_configuration_details.shape_name = "test_shape_name" job_configuration_override_details = MagicMock() - job_configuration_override_details.environment_variables = { - "NODE_COUNT" : 1 - } - job_run.job_infrastructure_configuration_details = job_infrastructure_configuration_details + job_configuration_override_details.environment_variables = {"NODE_COUNT": 1} + job_run.job_infrastructure_configuration_details = ( + job_infrastructure_configuration_details + ) job_run.job_configuration_override_details = job_configuration_override_details log_details = MagicMock() log_details.log_id = "test_log_id" log_details.log_group_id = "test_log_group_id" job_run.log_details = log_details response.data = job_run - self.app.ds_client.get_job_run = MagicMock( - return_value = response - ) + self.app.ds_client.get_job_run = MagicMock(return_value=response) query_resource = MagicMock() query_resource.display_name = "test_display_name" @@ -304,79 +332,395 @@ def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_res mock_query_resource.assert_called() assert asdict(model) == { - 'compartment_id': f'{ds_model.compartment_id}', - 'console_link': ( - f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}', + "compartment_id": f"{ds_model.compartment_id}", + "console_link": ( + f"https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}", ), - 'dataset': 'test_training_data', - 'experiment': {'id': '', 'name': '', 'url': ''}, - 'icon': '', - 'id': f'{ds_model.id}', - 'is_fine_tuned_model': True, - 'job': {'id': '', 'name': '', 'url': ''}, - 'license': 'test_license', - 'lifecycle_details': f'{job_run.lifecycle_details}', - 'lifecycle_state': f'{ds_model.lifecycle_state}', - 'log': { - 'id': f'{log_details.log_id}', - 'name': f'{query_resource.display_name}', - 'url': 'https://cloud.oracle.com/logging/search?searchQuery=search ' - f'"{job_run.compartment_id}/{log_details.log_group_id}/{log_details.log_id}" | ' - f"source='{job_run.id}' | sort by datetime desc®ions={self.app.region}" + "dataset": "test_training_data", + "experiment": {"id": "", "name": "", "url": ""}, + "icon": "", + "id": f"{ds_model.id}", + "is_fine_tuned_model": True, + "job": {"id": "", "name": "", "url": ""}, + "license": "test_license", + "lifecycle_details": f"{job_run.lifecycle_details}", + "lifecycle_state": f"{ds_model.lifecycle_state}", + "log": { + "id": f"{log_details.log_id}", + "name": f"{query_resource.display_name}", + "url": "https://cloud.oracle.com/logging/search?searchQuery=search " + f'"{job_run.compartment_id}/{log_details.log_group_id}/{log_details.log_id}" | ' + f"source='{job_run.id}' | sort by datetime desc®ions={self.app.region}", + }, + "log_group": { + "id": f"{log_details.log_group_id}", + "name": f"{query_resource.display_name}", + "url": f"https://cloud.oracle.com/logging/log-groups/{log_details.log_group_id}?region={self.app.region}", + }, + "metrics": [ + {"category": "validation", "name": "validation_metrics", "scores": []}, + {"category": "training", "name": "training_metrics", "scores": []}, + { + "category": "validation", + "name": "validation_metrics_final", + "scores": [], }, - 'log_group': { - 'id': f'{log_details.log_group_id}', - 'name': f'{query_resource.display_name}', - 'url': f'https://cloud.oracle.com/logging/log-groups/{log_details.log_group_id}?region={self.app.region}' + { + "category": "training", + "name": "training_metrics_final", + "scores": [], }, - 'metrics': [ + ], + "model_card": f"{mock_read_file.return_value}", + "name": f"{ds_model.display_name}", + "organization": "test_organization", + "project_id": f"{ds_model.project_id}", + "ready_to_deploy": True, + "ready_to_finetune": False, + "search_text": "ACTIVE,test_license,test_organization,test_task,test_finetuned_model", + "shape_info": { + "instance_shape": f"{job_infrastructure_configuration_details.shape_name}", + "replica": 1, + }, + "source": {"id": "", "name": "", "url": ""}, + "tags": ds_model.freeform_tags, + "task": "test_task", + "time_created": f"{ds_model.time_created}", + "validation": {"type": "Automatic split", "value": "test_val_set_size"}, + } + + @patch("huggingface_hub.snapshot_download") + @patch("subprocess.check_call") + def test_import_shadow_model( + self, + mock_subprocess, + mock_snapshot_download, + ): + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + huggingface_hub.HfApi.model_info = MagicMock(return_value={}) + DataScienceModel.upload_artifact = MagicMock() + # oci.data_science.DataScienceClient = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + ds_model = DataScienceModel() + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + hf_model = "oracle/aqua-1t-mega-model" + ds_freeform_tags = { + "OCI_AQUA": "ACTIVE", + "license": "aqua-license", + "organization": "oracle", + "task": "text-generation", + } + ds_model = ( + ds_model.with_compartment_id("test_model_compartment_id") + .with_project_id("test_project_id") + .with_display_name(hf_model) + .with_description("test_description") + .with_model_version_set_id("test_model_version_set_id") + .with_freeform_tags(**ds_freeform_tags) + .with_version_id("ocid1.blah.blah") + ) + custom_metadata_list = ModelCustomMetadata() + custom_metadata_list.add( + **{"key": "deployment-container", "value": "odsc-tgi-serving"} + ) + custom_metadata_list.add( + **{"key": "evaluation-container", "value": "odsc-llm-evaluate"} + ) + ds_model.with_custom_metadata_list(custom_metadata_list) + ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {}) + DataScienceModel.from_id = MagicMock(return_value=ds_model) + reload(ads.aqua.model) + app = AquaModelApp() + with tempfile.TemporaryDirectory() as tmpdir: + model: DataScienceModel = app.register( + model="ocid1.datasciencemodel.xxx.xxxx.", + os_path=os_path, + local_dir=str(tmpdir), + ) + mock_snapshot_download.assert_called_with( + repo_id=hf_model, + local_dir=f"{str(tmpdir)}/{hf_model}", + local_dir_use_symlinks=False, + ) + mock_subprocess.assert_called_with( + shlex.split( + f"oci os object bulk-upload --src-dir {str(tmpdir)}/{hf_model} --prefix prefix/path/{hf_model}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT" + ) + ) + assert model.freeform_tags == { + "aqua_custom_base_model": "true", + **ds_freeform_tags, + } + expected_metadata = [ { - 'category': 'validation', - 'name': 'validation_metrics', - 'scores': [] + "key": "evaluation-container", + "value": "odsc-llm-evaluate", + "description": "", + "category": "Other", }, { - 'category': 'training', - 'name': 'training_metrics', - 'scores': [] + "key": "deployment-container", + "value": "odsc-tgi-serving", + "description": "", + "category": "Other", }, { - 'category': 'validation', - 'name': 'validation_metrics_final', - 'scores': [] + "key": "modelDescription", + "value": "true", + "description": "model by reference flag", + "category": "Other", }, { - 'category': 'training', - 'name': 'training_metrics_final', - 'scores': [] + "key": "artifact_location", + "value": f"{os_path}/{hf_model}/", + "description": "artifact location", + "category": "Other", + }, + ] + for item in expected_metadata: + assert model.custom_metadata_list[item["key"]].to_dict() == item + assert model.version_id != ds_model.version_id + + @patch("huggingface_hub.snapshot_download") + @patch("subprocess.check_call") + def test_import_any_hf_model_no_containers_specified( + self, + mock_subprocess, + mock_snapshot_download, + ): + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + huggingface_hub.HfApi.model_info = MagicMock(return_value={}) + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + ds_model = DataScienceModel() + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + hf_model = "oracle/aqua-1t-mega-model" + ds_freeform_tags = { + "OCI_AQUA": "ACTIVE", + "license": "aqua-license", + "organization": "oracle", + "task": "text-generation", + } + + reload(ads.aqua.model) + app = AquaModelApp() + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError): + model: DataScienceModel = app.register( + model=hf_model, + os_path=os_path, + local_dir=str(tmpdir), + ) + + @patch("huggingface_hub.snapshot_download") + @patch("subprocess.check_call") + def test_import_any_hf_model_custom_container( + self, + mock_subprocess, + mock_snapshot_download, + ): + hf_model = "oracle/aqua-1t-mega-model" + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + huggingface_hub.HfApi.model_info = MagicMock( + return_value=huggingface_hub.hf_api.ModelInfo( + **{ + "id": hf_model, + "license": "aqua-license", + "author": "oracle", + "pipeline_tag": "text-generation", + "private": False, + "downloads": 100, + "likes": 10, + "tags": ["text-generation"], } - ], - 'model_card': f'{mock_read_file.return_value}', - 'name': f'{ds_model.display_name}', - 'organization': 'test_organization', - 'project_id': f'{ds_model.project_id}', - 'ready_to_deploy': True, - 'ready_to_finetune': False, - 'search_text': 'ACTIVE,test_license,test_organization,test_task,test_finetuned_model', - 'shape_info': { - 'instance_shape': f'{job_infrastructure_configuration_details.shape_name}', - 'replica': 1, - }, - 'source': {'id': '', 'name': '', 'url': ''}, - 'tags': ds_model.freeform_tags, - 'task': 'test_task', - 'time_created': f'{ds_model.time_created}', - 'validation': { - 'type': 'Automatic split', - 'value': 'test_val_set_size' + ) + ) + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + ds_freeform_tags = { + "OCI_AQUA": "active", + "organization": "oracle", + "task": "text-generation", + } + + reload(ads.aqua.model) + app = AquaModelApp() + with tempfile.TemporaryDirectory() as tmpdir: + model: DataScienceModel = app.register( + model=hf_model, + os_path=os_path, + local_dir=str(tmpdir), + inference_container="iad.ocir.io/my/own/md-container", + inference_container_type_smc=False, + finetuning_container="iad.ocir.io/my/own/ft-container", + finetuning_container_type_smc=False, + ) + mock_snapshot_download.assert_called_with( + repo_id=hf_model, + local_dir=f"{str(tmpdir)}/{hf_model}", + local_dir_use_symlinks=False, + ) + mock_subprocess.assert_called_with( + shlex.split( + f"oci os object bulk-upload --src-dir {str(tmpdir)}/{hf_model} --prefix prefix/path/{hf_model}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT" + ) + ) + assert model.freeform_tags == { + "aqua_custom_base_model": "true", + "aqua_finetuning": "true", + **ds_freeform_tags, } + expected_metadata = [ + { + "key": "evaluation-container", + "value": "odsc-llm-evaluate", + "description": "Evaluation container mapping for SMC", + "category": "Other", + }, + { + "key": "deployment-container", + "value": "iad.ocir.io/my/own/md-container", + "description": f"Inference container mapping for {hf_model}", + "category": "Other", + }, + { + "key": "finetune-container", + "value": "iad.ocir.io/my/own/ft-container", + "description": f"Fine-tuning container mapping for {hf_model}", + "category": "Other", + }, + { + "key": "modelDescription", + "value": "true", + "description": "model by reference flag", + "category": "Other", + }, + { + "key": "artifact_location", + "value": f"{os_path}/{hf_model}/", + "description": "artifact location", + "category": "Other", + }, + ] + for item in expected_metadata: + assert model.custom_metadata_list[item["key"]].to_dict() == item + assert model.version_id is None + + @patch("huggingface_hub.snapshot_download") + @patch("subprocess.check_call") + def test_import_any_hf_model_smc_container( + self, + mock_subprocess, + mock_snapshot_download, + ): + hf_model = "oracle/aqua-1t-mega-model" + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + huggingface_hub.HfApi.model_info = MagicMock( + return_value=huggingface_hub.hf_api.ModelInfo( + **{ + "id": hf_model, + "license": "aqua-license", + "author": "oracle", + "pipeline_tag": "text-generation", + "private": False, + "downloads": 100, + "likes": 10, + "tags": ["text-generation"], + } + ) + ) + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + ds_freeform_tags = { + "OCI_AQUA": "active", + "organization": "oracle", + "task": "text-generation", } + reload(ads.aqua.model) + app = AquaModelApp() + with tempfile.TemporaryDirectory() as tmpdir: + model: DataScienceModel = app.register( + model=hf_model, + os_path=os_path, + local_dir=str(tmpdir), + inference_container="dsmc://md-container", + inference_container_type_smc=False, + finetuning_container="dsmc://ft-container", + finetuning_container_type_smc=False, + ) + mock_snapshot_download.assert_called_with( + repo_id=hf_model, + local_dir=f"{str(tmpdir)}/{hf_model}", + local_dir_use_symlinks=False, + ) + mock_subprocess.assert_called_with( + shlex.split( + f"oci os object bulk-upload --src-dir {str(tmpdir)}/{hf_model} --prefix prefix/path/{hf_model}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT" + ) + ) + assert model.freeform_tags == { + "aqua_custom_base_model": "true", + "aqua_finetuning": "true", + **ds_freeform_tags, + } + expected_metadata = [ + { + "key": "evaluation-container", + "value": "odsc-llm-evaluate", + "description": "Evaluation container mapping for SMC", + "category": "Other", + }, + { + "key": "deployment-container", + "value": "dsmc://md-container", + "description": f"Inference container mapping for {hf_model}", + "category": "Other", + }, + { + "key": "finetune-container", + "value": "dsmc://ft-container", + "description": f"Fine-tuning container mapping for {hf_model}", + "category": "Other", + }, + { + "key": "modelDescription", + "value": "true", + "description": "model by reference flag", + "category": "Other", + }, + { + "key": "artifact_location", + "value": f"{os_path}/{hf_model}/", + "description": "artifact location", + "category": "Other", + }, + ] + for item in expected_metadata: + assert model.custom_metadata_list[item["key"]].to_dict() == item + assert model.version_id is None + @patch("ads.aqua.model.read_file") @patch("ads.aqua.model.get_artifact_path") def test_load_license(self, mock_get_artifact_path, mock_read_file): self.app.ds_client.get_model = MagicMock() - mock_get_artifact_path.return_value = "oci://bucket@namespace/prefix/config/LICENSE.txt" + mock_get_artifact_path.return_value = ( + "oci://bucket@namespace/prefix/config/LICENSE.txt" + ) mock_read_file.return_value = "test_license" license = self.app.load_license(model_id="test_model_id") @@ -384,9 +728,7 @@ def test_load_license(self, mock_get_artifact_path, mock_read_file): mock_get_artifact_path.assert_called() mock_read_file.assert_called() - assert asdict(license) == { - 'id': 'test_model_id', 'license': 'test_license' - } + assert asdict(license) == {"id": "test_model_id", "license": "test_license"} def test_list_service_models(self): """Tests listing service models succesfully.""" @@ -426,7 +768,7 @@ def test_list_custom_models(self): results = self.app.list(TestDataset.COMPARTMENT_ID) - self.app._rqs.assert_called_with(TestDataset.COMPARTMENT_ID) + self.app._rqs.assert_called_with(TestDataset.COMPARTMENT_ID, model_type="FT") assert len(results) == 1