diff --git a/ads/common/utils.py b/ads/common/utils.py
index 12242ec7c..49afcaada 100644
--- a/ads/common/utils.py
+++ b/ads/common/utils.py
@@ -53,6 +53,9 @@
 from ads.dataset.progress import DummyProgressBar, TqdmProgressBar
 from . import auth as authutil
+from oci import object_storage
+from ads.common.oci_client import OCIClientFactory
+from ads.common.object_storage_details import ObjectStorageDetails
 
 # For Model / Model Artifact libraries
 lib_translator = {"sklearn": "scikit-learn"}
@@ -100,6 +103,9 @@
 # declare custom exception class
+# The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
+DEFAULT_PARALLEL_PROCESS_COUNT = 9
+
 class FileOverwriteError(Exception):  # pragma: no cover
     pass
@@ -1599,3 +1605,103 @@ def is_path_exists(uri: str, auth: Optional[Dict] = None) -> bool:
     if fsspec.filesystem(path_scheme, **storage_options).exists(uri):
         return True
     return False
+
+
+def upload_to_os(
+    src_uri: str,
+    dst_uri: str,
+    auth: dict = None,
+    parallel_process_count: int = DEFAULT_PARALLEL_PROCESS_COUNT,
+    progressbar_description: str = "Uploading `{src_uri}` to `{dst_uri}`.",
+    force_overwrite: bool = False,
+):
+    """Utilizes `oci.object_storage.UploadManager` to upload a file to Object Storage.
+
+    Parameters
+    ----------
+    src_uri: str
+        The path to the file to upload. This should be a local path.
+    dst_uri: str
+        Object Storage path, e.g. `oci://my-bucket@my-tenancy/prefix`.
+    auth: (Dict, optional). Defaults to None.
+        The authentication signer and client kwargs. If not provided,
+        `authutil.default_signer()` is used.
+    parallel_process_count: (int, optional). Defaults to `DEFAULT_PARALLEL_PROCESS_COUNT` (9).
+        The number of worker processes to use in parallel for uploading individual
+        parts of a multipart upload.
+    progressbar_description: (str, optional). Defaults to `"Uploading `{src_uri}` to `{dst_uri}`."`.
+        Prefix for the progressbar.
+    force_overwrite: (bool, optional). Defaults to False.
+        Whether to overwrite existing files or not.
+
+    Returns
+    -------
+    Response: oci.response.Response
+        The response from the multipart commit operation or the put operation.
+
+    Raises
+    ------
+    ValueError
+        When the given `dst_uri` is not a valid Object Storage path.
+    FileNotFoundError
+        When the given `src_uri` does not exist.
+    RuntimeError
+        When the upload operation fails.
+    """
+    if not os.path.exists(src_uri):
+        raise FileNotFoundError(f"The given src_uri: {src_uri} does not exist.")
+
+    if not ObjectStorageDetails.is_oci_path(
+        dst_uri
+    ) or not ObjectStorageDetails.is_valid_uri(dst_uri):
+        raise ValueError(
+            f"The given dst_uri: {dst_uri} is not a valid Object Storage path."
+        )
+
+    auth = auth or authutil.default_signer()
+
+    if not force_overwrite and is_path_exists(uri=dst_uri, auth=auth):
+        raise FileExistsError(
+            f"The `{dst_uri}` exists. Please use a new file name or "
+            "set force_overwrite to True if you wish to overwrite."
+        )
+
+    upload_manager = object_storage.UploadManager(
+        object_storage_client=OCIClientFactory(**auth).object_storage,
+        parallel_process_count=parallel_process_count,
+        allow_multipart_uploads=True,
+        allow_parallel_uploads=True,
+    )
+
+    file_size = os.path.getsize(src_uri)
+    with open(src_uri, "rb") as fs:
+        with tqdm(
+            total=file_size,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+            position=0,
+            leave=False,
+            file=sys.stdout,
+            desc=progressbar_description,
+        ) as pbar:
+
+            def progress_callback(progress):
+                pbar.update(progress)
+
+            bucket_details = ObjectStorageDetails.from_path(dst_uri)
+            response = upload_manager.upload_stream(
+                namespace_name=bucket_details.namespace,
+                bucket_name=bucket_details.bucket,
+                object_name=bucket_details.filepath,
+                stream_ref=fs,
+                progress_callback=progress_callback,
+            )
+
+    if response.status == 200:
+        print(f"{src_uri} has been successfully uploaded to {dst_uri}.")
+    else:
+        raise RuntimeError(
+            f"Failed to upload {src_uri}. Response code is {response.status}"
+        )
+
+    return response
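Reviewer note: a minimal usage sketch of the new helper, assuming a valid OCI config that `authutil.default_signer()` can pick up. The bucket, namespace, and file names below are made up.

```python
from ads.common import auth as authutil
from ads.common.utils import upload_to_os

# Hypothetical local file and Object Storage destination.
response = upload_to_os(
    src_uri="/tmp/model_artifact.zip",
    dst_uri="oci://my-bucket@my-namespace/artifacts/model_artifact.zip",
    auth=authutil.default_signer(),  # optional; the helper falls back to this
    parallel_process_count=9,  # the DEFAULT_PARALLEL_PROCESS_COUNT value
    force_overwrite=True,
)
print(response.status)  # 200 on success
```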
diff --git a/ads/model/artifact_uploader.py b/ads/model/artifact_uploader.py
index b40840708..260761d34 100644
--- a/ads/model/artifact_uploader.py
+++ b/ads/model/artifact_uploader.py
@@ -94,6 +94,8 @@ def _upload(self):
 
 class SmallArtifactUploader(ArtifactUploader):
+    """The helper class to upload small model artifacts."""
+
     PROGRESS_STEPS_COUNT = 1
 
     def _upload(self):
@@ -104,6 +106,39 @@ def _upload(self):
 
 class LargeArtifactUploader(ArtifactUploader):
+    """
+    The helper class to upload large model artifacts.
+
+    Attributes
+    ----------
+    artifact_path: str
+        The model artifact location.
+    artifact_zip_path: str
+        The URI of the zipped model artifact.
+    auth: dict
+        The default authentication is set using the `ads.set_auth` API.
+        If you need to override the default, use the `ads.common.auth.api_keys` or
+        `ads.common.auth.resource_principal` to create an appropriate authentication signer
+        and kwargs required to instantiate an IdentityClient object.
+    bucket_uri: str
+        The OCI Object Storage URI where model artifacts will be copied to.
+        The `bucket_uri` is only necessary for uploading large artifacts whose
+        size is greater than 2GB. Example: `oci://@/prefix/`.
+    dsc_model: OCIDataScienceModel
+        The data science model instance.
+    overwrite_existing_artifact: bool
+        Overwrite target bucket artifact if exists.
+    progress: TqdmProgressBar
+        An instance of the TqdmProgressBar.
+    region: str
+        The destination Object Storage bucket region.
+        By default the value will be extracted from the `OCI_REGION_METADATA` environment variable.
+    remove_existing_artifact: bool
+        Whether artifacts uploaded to the object storage bucket need to be removed or not.
+    upload_manager: UploadManager
+        The `UploadManager` simplifies interaction with the Object Storage service.
+    """
+
     PROGRESS_STEPS_COUNT = 4
 
     def __init__(
@@ -115,6 +150,7 @@ def __init__(
         region: Optional[str] = None,
         overwrite_existing_artifact: Optional[bool] = True,
         remove_existing_artifact: Optional[bool] = True,
+        parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
     ):
         """Initializes `LargeArtifactUploader` instance.
@@ -139,7 +175,9 @@ def __init__(
         overwrite_existing_artifact: (bool, optional). Defaults to `True`.
             Overwrite target bucket artifact if exists.
         remove_existing_artifact: (bool, optional). Defaults to `True`.
-            Wether artifacts uploaded to object storage bucket need to be removed or not.
+            Whether artifacts uploaded to object storage bucket need to be removed or not.
+        parallel_process_count: (int, optional). Defaults to `utils.DEFAULT_PARALLEL_PROCESS_COUNT` (9).
+            The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
         """
         if not bucket_uri:
             raise ValueError("The `bucket_uri` must be provided.")
@@ -150,36 +188,45 @@ def __init__(
         self.bucket_uri = bucket_uri
         self.overwrite_existing_artifact = overwrite_existing_artifact
         self.remove_existing_artifact = remove_existing_artifact
+        self._parallel_process_count = parallel_process_count
 
     def _upload(self):
         """Uploads model artifacts to the model catalog."""
         self.progress.update("Copying model artifact to the Object Storage bucket")
-        try:
-            bucket_uri = self.bucket_uri
-            bucket_uri_file_name = os.path.basename(bucket_uri)
+        bucket_uri = self.bucket_uri
+        bucket_uri_file_name = os.path.basename(bucket_uri)
 
-            if not bucket_uri_file_name:
-                bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
-            elif not bucket_uri.lower().endswith(".zip"):
-                bucket_uri = f"{bucket_uri}.zip"
+        if not bucket_uri_file_name:
+            bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
+        elif not bucket_uri.lower().endswith(".zip"):
+            bucket_uri = f"{bucket_uri}.zip"
 
-            bucket_file_name = utils.copy_file(
-                self.artifact_zip_path,
-                bucket_uri,
-                force_overwrite=self.overwrite_existing_artifact,
-                auth=self.auth,
-                progressbar_description="Copying model artifact to the Object Storage bucket",
-            )
-        except FileExistsError:
+        if not self.overwrite_existing_artifact and utils.is_path_exists(
+            uri=bucket_uri, auth=self.auth
+        ):
             raise FileExistsError(
-                f"The `{self.bucket_uri}` exists. Please use a new file name or "
+                f"The bucket_uri=`{self.bucket_uri}` exists. Please use a new file name or "
                 "set `overwrite_existing_artifact` to `True` if you wish to overwrite."
             )
+
+        try:
+            utils.upload_to_os(
+                src_uri=self.artifact_zip_path,
+                dst_uri=bucket_uri,
+                auth=self.auth,
+                parallel_process_count=self._parallel_process_count,
+                force_overwrite=self.overwrite_existing_artifact,
+                progressbar_description="Copying model artifact to the Object Storage bucket.",
+            )
+        except Exception as ex:
+            raise RuntimeError(
+                f"Failed to upload model artifact to the given Object Storage path `{self.bucket_uri}`. "
+                f"See Exception: {ex}"
+            )
+
         self.progress.update("Exporting model artifact to the model catalog")
-        self.dsc_model.export_model_artifact(
-            bucket_uri=bucket_file_name, region=self.region
-        )
+        self.dsc_model.export_model_artifact(bucket_uri=bucket_uri, region=self.region)
 
         if self.remove_existing_artifact:
             self.progress.update(
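Reviewer note: how the new `parallel_process_count` knob reaches the uploader, as a sketch. `dsc_model` stands in for an existing `OCIDataScienceModel` instance; the region, bucket, and paths are placeholders.

```python
from ads.common import auth as authutil
from ads.model.artifact_uploader import LargeArtifactUploader

uploader = LargeArtifactUploader(
    dsc_model=dsc_model,  # assumed: an existing OCIDataScienceModel instance
    artifact_path="/tmp/large_model_artifacts",
    auth=authutil.default_signer(),
    region="us-ashburn-1",  # placeholder region
    bucket_uri="oci://my-bucket@my-namespace/artifacts/",
    overwrite_existing_artifact=True,
    remove_existing_artifact=True,
    parallel_process_count=9,  # forwarded to utils.upload_to_os
)
# Zips the artifact if needed, uploads it to Object Storage,
# then exports it to the model catalog.
uploader.upload()
```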
diff --git a/ads/model/datascience_model.py b/ads/model/datascience_model.py
index 4a5cdc120..8bbf6d0da 100644
--- a/ads/model/datascience_model.py
+++ b/ads/model/datascience_model.py
@@ -35,7 +35,7 @@
 _MAX_ARTIFACT_SIZE_IN_BYTES = 2147483648  # 2GB
 
-class ModelArtifactSizeError(Exception): # pragma: no cover
+class ModelArtifactSizeError(Exception):  # pragma: no cover
     def __init__(self, max_artifact_size: str):
         super().__init__(
             f"The model artifacts size is greater than `{max_artifact_size}`. "
@@ -562,6 +562,8 @@ def create(self, **kwargs) -> "DataScienceModel":
             and kwargs required to instantiate IdentityClient object.
         timeout: (int, optional). Defaults to 10 seconds.
             The connection timeout in seconds for the client.
+        parallel_process_count: (int, optional).
+            The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
 
         Returns
         -------
@@ -607,6 +609,7 @@ def create(self, **kwargs) -> "DataScienceModel":
                 region=kwargs.pop("region", None),
                 auth=kwargs.pop("auth", None),
                 timeout=kwargs.pop("timeout", None),
+                parallel_process_count=kwargs.pop("parallel_process_count", utils.DEFAULT_PARALLEL_PROCESS_COUNT),
             )
 
         # Sync up model
@@ -623,6 +626,7 @@ def upload_artifact(
         overwrite_existing_artifact: Optional[bool] = True,
         remove_existing_artifact: Optional[bool] = True,
         timeout: Optional[int] = None,
+        parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
     ) -> None:
         """Uploads model artifacts to the model catalog.
@@ -646,6 +650,8 @@ def upload_artifact(
-            Wether artifacts uploaded to object storage bucket need to be removed or not.
+            Whether artifacts uploaded to object storage bucket need to be removed or not.
         timeout: (int, optional). Defaults to 10 seconds.
             The connection timeout in seconds for the client.
+        parallel_process_count: (int, optional). Defaults to `utils.DEFAULT_PARALLEL_PROCESS_COUNT` (9).
+            The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
         """
         # Upload artifact to the model catalog
         if not self.artifact:
@@ -676,6 +682,7 @@ def upload_artifact(
                 bucket_uri=bucket_uri,
                 overwrite_existing_artifact=overwrite_existing_artifact,
                 remove_existing_artifact=remove_existing_artifact,
+                parallel_process_count=parallel_process_count,
             )
         else:
             artifact_uploader = SmallArtifactUploader(
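Reviewer note: end to end, the new kwarg flows `create()` → `upload_artifact()` → `LargeArtifactUploader`, and only matters for artifacts above the 2GB `_MAX_ARTIFACT_SIZE_IN_BYTES` threshold. A sketch with placeholder values, assuming the usual `with_*` builder methods:

```python
from ads.model.datascience_model import DataScienceModel

model = (
    DataScienceModel()
    .with_display_name("my-large-model")  # placeholder name
    .with_artifact("/tmp/large_model_artifacts")
)
# For a >2GB artifact, create() routes the upload through LargeArtifactUploader
# and forwards parallel_process_count to the multipart upload.
model.create(
    bucket_uri="oci://my-bucket@my-namespace/artifacts/",
    overwrite_existing_artifact=True,
    parallel_process_count=9,
)
```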
) + """ # Set default display_name if not specified - randomly generated easy to remember name generated if not display_name: @@ -1951,6 +1966,7 @@ def save( bucket_uri=bucket_uri, overwrite_existing_artifact=overwrite_existing_artifact, remove_existing_artifact=remove_existing_artifact, + parallel_process_count=parallel_process_count, **kwargs, ) @@ -2151,10 +2167,10 @@ def deploy( "log_id", None ) or self.properties.deployment_predict_log_id, - deployment_image = getattr(existing_runtime, "image", None) + deployment_image=getattr(existing_runtime, "image", None) or self.properties.deployment_image, - deployment_instance_subnet_id = existing_infrastructure.subnet_id - or self.properties.deployment_instance_subnet_id + deployment_instance_subnet_id=existing_infrastructure.subnet_id + or self.properties.deployment_instance_subnet_id, ).to_dict() property_dict.update(override_properties) @@ -2228,25 +2244,18 @@ def deploy( runtime = None if self.properties.deployment_image: - image_digest = ( - kwargs.pop("image_digest", None) - or getattr(existing_runtime, "image_digest", None) - ) - cmd = ( - kwargs.pop("cmd", []) - or getattr(existing_runtime, "cmd", []) + image_digest = kwargs.pop("image_digest", None) or getattr( + existing_runtime, "image_digest", None ) - entrypoint = ( - kwargs.pop("entrypoint", []) - or getattr(existing_runtime, "entrypoint", []) + cmd = kwargs.pop("cmd", []) or getattr(existing_runtime, "cmd", []) + entrypoint = kwargs.pop("entrypoint", []) or getattr( + existing_runtime, "entrypoint", [] ) - server_port = ( - kwargs.pop("server_port", None) - or getattr(existing_runtime, "server_port", None) + server_port = kwargs.pop("server_port", None) or getattr( + existing_runtime, "server_port", None ) - health_check_port = ( - kwargs.pop("health_check_port", None) - or getattr(existing_runtime, "health_check_port", None) + health_check_port = kwargs.pop("health_check_port", None) or getattr( + existing_runtime, "health_check_port", None ) runtime = ( ModelDeploymentContainerRuntime() @@ -2854,6 +2863,7 @@ def upload_artifact( uri: str, auth: Optional[Dict] = None, force_overwrite: Optional[bool] = False, + parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT, ) -> None: """Uploads model artifacts to the provided `uri`. The artifacts will be zipped before uploading. @@ -2873,6 +2883,8 @@ def upload_artifact( authentication signer and kwargs required to instantiate IdentityClient object. force_overwrite: bool Overwrite target_dir if exists. + parallel_process_count: (int, optional) + The number of worker processes to use in parallel for uploading individual parts of a multipart upload. """ if not uri: raise ValueError("The `uri` must be provided.") @@ -2887,19 +2899,34 @@ def upload_artifact( uri = os.path.join(uri, f"{self.model_id}.zip") tmp_artifact_zip_path = None + progressbar_description = f"Uploading an artifact ZIP archive to {uri}." 
@@ -2854,6 +2863,7 @@ def upload_artifact(
         uri: str,
         auth: Optional[Dict] = None,
         force_overwrite: Optional[bool] = False,
+        parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
     ) -> None:
         """Uploads model artifacts to the provided `uri`.
         The artifacts will be zipped before uploading.
@@ -2873,6 +2883,8 @@ def upload_artifact(
             authentication signer and kwargs required to instantiate IdentityClient object.
         force_overwrite: bool
             Overwrite target_dir if exists.
+        parallel_process_count: (int, optional). Defaults to `utils.DEFAULT_PARALLEL_PROCESS_COUNT` (9).
+            The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
         """
         if not uri:
             raise ValueError("The `uri` must be provided.")
@@ -2887,19 +2899,34 @@ def upload_artifact(
             uri = os.path.join(uri, f"{self.model_id}.zip")
 
         tmp_artifact_zip_path = None
+        progressbar_description = f"Uploading an artifact ZIP archive to {uri}."
         try:
             # Zip artifacts
             tmp_artifact_zip_path = zip_artifact(self.artifact_dir)
             # Upload artifacts to the provided destination
-            utils.copy_file(
-                uri_src=tmp_artifact_zip_path,
-                uri_dst=uri,
-                auth=auth,
-                force_overwrite=force_overwrite,
-                progressbar_description=f"Uploading an artifact ZIP archive to the {uri}",
+            if ObjectStorageDetails.is_oci_path(
+                uri
+            ) and ObjectStorageDetails.is_valid_uri(uri):
+                utils.upload_to_os(
+                    src_uri=tmp_artifact_zip_path,
+                    dst_uri=uri,
+                    auth=auth,
+                    force_overwrite=force_overwrite,
+                    parallel_process_count=parallel_process_count,
+                    progressbar_description=progressbar_description,
+                )
+            else:
+                utils.copy_file(
+                    uri_src=tmp_artifact_zip_path,
+                    uri_dst=uri,
+                    auth=auth,
+                    force_overwrite=force_overwrite,
+                    progressbar_description=progressbar_description,
+                )
+        except Exception as ex:
+            raise RuntimeError(
+                f"Failed to upload model artifact to the given path `{uri}`. "
+                f"See Exception: {ex}"
             )
-        except Exception:
-            raise
         finally:
             if tmp_artifact_zip_path:
                 os.remove(tmp_artifact_zip_path)
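Reviewer note: with this change the destination URI decides the code path. A sketch, assuming `model` is an already-prepared `GenericModel` instance and the bucket name is made up:

```python
# An Object Storage URI is routed through the new utils.upload_to_os helper,
# which performs a parallel multipart upload.
model.upload_artifact(
    uri="oci://my-bucket@my-namespace/artifacts/",  # placeholder bucket
    parallel_process_count=9,
    force_overwrite=True,
)

# Any other destination (e.g. a local directory) still uses utils.copy_file.
model.upload_artifact(uri="/tmp/backups/", force_overwrite=True)
```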
+ """ + mock_default_signer.return_value = DEFAULT_SIGNER_CONF + mock_file_exists.return_value = True + mock_is_path_exists = True + with pytest.raises(ValueError): + upload_to_os(src_uri="fake_uri", dst_uri="This is an invalid oci path.") + + with pytest.raises(FileExistsError): + upload_to_os( + src_uri="fake_uri", + dst_uri="oci://my-bucket@my-tenancy/prefix", + force_overwrite=False, + ) + + @patch("ads.common.oci_client.OCIClientFactory.object_storage") + @patch("ads.common.utils.is_path_exists") + @patch.object(object_storage.UploadManager, "upload_stream") + @patch.object(object_storage.UploadManager, "__init__", return_value=None) + def test_upload_to_os( + self, + mock_init, + mock_upload, + mock_is_path_exists, + mock_client, + ): + """Tests upload_to_os successfully.""" + + class MockResponse: + def __init__(self, status_code): + self.status = status_code + + mock_upload.return_value = MockResponse(200) + dst_namespace = "my-tenancy" + dst_bucket = "my-bucket" + dst_prefix = "prefix" + parallel_process_count = 3 + uri_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "./test_files/archive/1.txt" + ) + response = upload_to_os( + src_uri=uri_src, + dst_uri=f"oci://{dst_bucket}@{dst_namespace}/{dst_prefix}", + force_overwrite=True, + parallel_process_count=parallel_process_count, + ) + mock_init.assert_called_with( + object_storage_client=mock_client, + parallel_process_count=parallel_process_count, + allow_multipart_uploads=True, + allow_parallel_uploads=True, + ) + mock_upload.assert_called_with( + namespace_name=dst_namespace, + bucket_name=dst_bucket, + object_name=dst_prefix, + stream_ref=ANY, + progress_callback=ANY, + ) + assert response.status == 200 diff --git a/tests/unitary/default_setup/model/test_artifact_uploader.py b/tests/unitary/default_setup/model/test_artifact_uploader.py index 48a41e7f5..bc9daeabf 100644 --- a/tests/unitary/default_setup/model/test_artifact_uploader.py +++ b/tests/unitary/default_setup/model/test_artifact_uploader.py @@ -13,6 +13,9 @@ import pytest from ads.model.artifact_uploader import LargeArtifactUploader, SmallArtifactUploader from ads.model.common.utils import zip_artifact +from ads.common.auth import default_signer +from ads.common.utils import DEFAULT_PARALLEL_PROCESS_COUNT +from oci import object_storage MODEL_OCID = "ocid1.datasciencemodel.oc1.xxx" @@ -60,7 +63,6 @@ def test__init__(self): # Ensures the LargeArtifactUploader can be successfully initialized with patch("os.path.exists", return_value=True): - with pytest.raises(ValueError, match="The `bucket_uri` must be provided."): lg_artifact_uploader = LargeArtifactUploader( dsc_model=self.mock_dsc_model, @@ -71,11 +73,11 @@ def test__init__(self): overwrite_existing_artifact=False, remove_existing_artifact=False, ) - + auth = default_signer() lg_artifact_uploader = LargeArtifactUploader( dsc_model=self.mock_dsc_model, artifact_path="existing_path", - auth=self.mock_auth, + auth=auth, region=self.mock_region, bucket_uri="test_bucket_uri", overwrite_existing_artifact=False, @@ -85,14 +87,17 @@ def test__init__(self): assert lg_artifact_uploader.artifact_path == "existing_path" assert lg_artifact_uploader.artifact_zip_path == None assert lg_artifact_uploader.progress == None - assert lg_artifact_uploader.auth == self.mock_auth + assert lg_artifact_uploader.auth == auth assert lg_artifact_uploader.region == self.mock_region assert lg_artifact_uploader.bucket_uri == "test_bucket_uri" assert lg_artifact_uploader.overwrite_existing_artifact == False assert 
diff --git a/tests/unitary/default_setup/model/test_artifact_uploader.py b/tests/unitary/default_setup/model/test_artifact_uploader.py
index 48a41e7f5..bc9daeabf 100644
--- a/tests/unitary/default_setup/model/test_artifact_uploader.py
+++ b/tests/unitary/default_setup/model/test_artifact_uploader.py
@@ -13,6 +13,9 @@
 import pytest
 from ads.model.artifact_uploader import LargeArtifactUploader, SmallArtifactUploader
 from ads.model.common.utils import zip_artifact
+from ads.common.auth import default_signer
+from ads.common.utils import DEFAULT_PARALLEL_PROCESS_COUNT
+from oci import object_storage
 
 MODEL_OCID = "ocid1.datasciencemodel.oc1.xxx"
@@ -60,7 +63,6 @@ def test__init__(self):
         # Ensures the LargeArtifactUploader can be successfully initialized
         with patch("os.path.exists", return_value=True):
-
             with pytest.raises(ValueError, match="The `bucket_uri` must be provided."):
                 lg_artifact_uploader = LargeArtifactUploader(
                     dsc_model=self.mock_dsc_model,
@@ -71,11 +73,11 @@ def test__init__(self):
                     overwrite_existing_artifact=False,
                     remove_existing_artifact=False,
                 )
-
+            auth = default_signer()
             lg_artifact_uploader = LargeArtifactUploader(
                 dsc_model=self.mock_dsc_model,
                 artifact_path="existing_path",
-                auth=self.mock_auth,
+                auth=auth,
                 region=self.mock_region,
                 bucket_uri="test_bucket_uri",
                 overwrite_existing_artifact=False,
@@ -85,14 +87,17 @@ def test__init__(self):
             assert lg_artifact_uploader.artifact_path == "existing_path"
             assert lg_artifact_uploader.artifact_zip_path == None
             assert lg_artifact_uploader.progress == None
-            assert lg_artifact_uploader.auth == self.mock_auth
+            assert lg_artifact_uploader.auth == auth
             assert lg_artifact_uploader.region == self.mock_region
             assert lg_artifact_uploader.bucket_uri == "test_bucket_uri"
             assert lg_artifact_uploader.overwrite_existing_artifact == False
             assert lg_artifact_uploader.remove_existing_artifact == False
+            assert (
+                lg_artifact_uploader._parallel_process_count
+                == DEFAULT_PARALLEL_PROCESS_COUNT
+            )
 
     def test_prepare_artiact_tmp_zip(self):
-        # Tests case when a folder provided as artifacts location
         with patch("ads.model.common.utils.zip_artifact") as mock_zip_artifact:
             mock_zip_artifact.return_value = "test_artifact.zip"
@@ -167,50 +172,48 @@ def test_upload_small_artifact(self):
             mock_remove_artiact_tmp_zip.assert_called()
             self.mock_dsc_model.create_model_artifact.assert_called()
 
-    def test_upload_large_artifact(self):
-        with tempfile.TemporaryDirectory() as tmp_artifact_dir:
-            test_bucket_file_name = os.path.join(tmp_artifact_dir, f"{MODEL_OCID}.zip")
-            # Case when artifact will be created and left in the TMP folder
+    @patch("ads.common.utils.is_path_exists")
+    @patch("ads.common.utils.upload_to_os")
+    def test_upload_large_artifact(self, mock_upload, mock_is_path_exists):
+        # Case when artifact already exists and overwrite_existing_artifact==True
+        dest_path = "oci://my-bucket@my-namespace/my-artifact-path"
+        test_bucket_file_name = os.path.join(dest_path, f"{MODEL_OCID}.zip")
+        mock_is_path_exists.return_value = True
+        auth = default_signer()
+        artifact_uploader = LargeArtifactUploader(
+            dsc_model=self.mock_dsc_model,
+            artifact_path=self.mock_artifact_zip_path,
+            bucket_uri=dest_path + "/",
+            auth=auth,
+            region=self.mock_region,
+            overwrite_existing_artifact=True,
+            remove_existing_artifact=False,
+        )
+        artifact_uploader.upload()
+        mock_upload.assert_called_with(
+            src_uri=self.mock_artifact_zip_path,
+            dst_uri=test_bucket_file_name,
+            auth=auth,
+            parallel_process_count=DEFAULT_PARALLEL_PROCESS_COUNT,
+            force_overwrite=True,
+            progressbar_description="Copying model artifact to the Object Storage bucket.",
+        )
+        self.mock_dsc_model.export_model_artifact.assert_called_with(
+            bucket_uri=test_bucket_file_name, region=self.mock_region
+        )
+
+        # Case when artifact already exists and overwrite_existing_artifact==False
+        with pytest.raises(FileExistsError):
             artifact_uploader = LargeArtifactUploader(
                 dsc_model=self.mock_dsc_model,
-                artifact_path=self.mock_artifact_path,
-                bucket_uri=tmp_artifact_dir + "/",
-                auth=self.mock_auth,
+                artifact_path=self.mock_artifact_zip_path,
+                bucket_uri=dest_path + "/",
+                auth=default_signer(),
                 region=self.mock_region,
                 overwrite_existing_artifact=False,
                 remove_existing_artifact=False,
             )
             artifact_uploader.upload()
-            self.mock_dsc_model.export_model_artifact.assert_called_with(
-                bucket_uri=test_bucket_file_name, region=self.mock_region
-            )
-            assert os.path.exists(test_bucket_file_name)
-
-            # Case when artifact already exists and overwrite_existing_artifact==False
-            with pytest.raises(FileExistsError):
-                artifact_uploader = LargeArtifactUploader(
-                    dsc_model=self.mock_dsc_model,
-                    artifact_path=self.mock_artifact_path,
-                    bucket_uri=tmp_artifact_dir + "/",
-                    auth=self.mock_auth,
-                    region=self.mock_region,
-                    overwrite_existing_artifact=False,
-                    remove_existing_artifact=False,
-                )
-                artifact_uploader.upload()
-
-            # Case when artifact already exists and overwrite_existing_artifact==True
-            artifact_uploader = LargeArtifactUploader(
-                dsc_model=self.mock_dsc_model,
-                artifact_path=self.mock_artifact_path,
-                bucket_uri=tmp_artifact_dir + "/",
-                auth=self.mock_auth,
-                region=self.mock_region,
-                overwrite_existing_artifact=True,
-                remove_existing_artifact=True,
-            )
-            artifact_uploader.upload()
-            assert not os.path.exists(test_bucket_file_name)
 
     def test_zip_artifact_fail(self):
         with pytest.raises(ValueError, match="The `artifact_dir` must be provided."):
diff --git a/tests/unitary/default_setup/model/test_datascience_model.py b/tests/unitary/default_setup/model/test_datascience_model.py
index 5c0b4d673..cc357f5e6 100644
--- a/tests/unitary/default_setup/model/test_datascience_model.py
+++ b/tests/unitary/default_setup/model/test_datascience_model.py
@@ -156,7 +156,6 @@
 class TestDataScienceModel:
-
     DEFAULT_PROPERTIES_PAYLOAD = {
         "compartmentId": DSC_MODEL_PAYLOAD["compartmentId"],
         "projectId": DSC_MODEL_PAYLOAD["projectId"],
@@ -368,6 +367,7 @@ def test_create_success(
                 bucket_uri="test_bucket_uri",
                 overwrite_existing_artifact=False,
                 remove_existing_artifact=False,
+                parallel_process_count=3,
             )
             mock_oci_dsc_model_create.assert_called()
             mock_create_model_provenance.assert_called_with(
@@ -380,6 +380,7 @@ def test_create_success(
                 region=None,
                 auth=None,
                 timeout=None,
+                parallel_process_count=3,
             )
             mock_sync.assert_called()
             assert self.prepare_dict(result.to_dict()["spec"]) == self.prepare_dict(
@@ -622,6 +623,7 @@ def test_upload_artifact(self):
                 bucket_uri="test_bucket_uri",
                 overwrite_existing_artifact=False,
                 remove_existing_artifact=False,
+                parallel_process_count=utils.DEFAULT_PARALLEL_PROCESS_COUNT,
             )
             mock_upload.assert_called()
@@ -659,7 +661,6 @@ def test_download_artifact(self):
             LargeArtifactDownloader, "__init__", return_value=None
         ) as mock_init:
             with patch.object(LargeArtifactDownloader, "download") as mock_download:
-
                 # If artifact is large and bucket_uri not provided
                 with pytest.raises(ModelArtifactSizeError):
                     self.mock_dsc_model.download_artifact(target_dir="test_target_dir")
diff --git a/tests/unitary/with_extras/model/test_generic_model.py b/tests/unitary/with_extras/model/test_generic_model.py
index f0e1a1f63..7773b8a84 100644
--- a/tests/unitary/with_extras/model/test_generic_model.py
+++ b/tests/unitary/with_extras/model/test_generic_model.py
@@ -368,6 +368,7 @@ def test_save(self, mock_dsc_model_create, mock__random_display_name):
             bucket_uri=None,
             overwrite_existing_artifact=True,
             remove_existing_artifact=True,
+            parallel_process_count=utils.DEFAULT_PARALLEL_PROCESS_COUNT,
         )
 
     def test_save_not_implemented_error(self):
@@ -606,7 +607,10 @@ def test_deploy_success(self, mock_deploy):
             "ocpus": input_dict["deployment_ocpus"],
             "memory_in_gbs": input_dict["deployment_memory_in_gbs"],
         }
-        assert result.infrastructure.subnet_id == input_dict["deployment_instance_subnet_id"]
+        assert (
+            result.infrastructure.subnet_id
+            == input_dict["deployment_instance_subnet_id"]
+        )
         assert result.runtime.image == input_dict["deployment_image"]
         assert result.runtime.entrypoint == input_dict["entrypoint"]
         assert result.runtime.server_port == input_dict["server_port"]
@@ -994,9 +998,7 @@ def test_from_model_deployment(
             compartment_id="test_compartment_id",
         )
 
-        mock_from_id.assert_called_with(
-            test_model_deployment_id
-        )
+        mock_from_id.assert_called_with(test_model_deployment_id)
         mock_from_model_catalog.assert_called_with(
             model_id=test_model_id,
             model_file_name="test.pkl",
@@ -1049,9 +1051,7 @@ def test_from_model_deployment_fail(
                 remove_existing_artifact=True,
                 compartment_id="test_compartment_id",
             )
-            mock_from_id.assert_called_with(
-                test_model_deployment_id
-            )
+            mock_from_id.assert_called_with(test_model_deployment_id)
 
     @patch.object(ModelDeployment, "update")
     @patch.object(ModelDeployment, "from_id")
@@ -1086,9 +1086,7 @@ def test_update_deployment_class_level(
             poll_interval=200,
         )
 
-        mock_from_id.assert_called_with(
-            test_model_deployment_id
-        )
+        mock_from_id.assert_called_with(test_model_deployment_id)
         mock_update.assert_called_with(
             properties=None,