Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/multi model artifact handler #869

Merged
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6d9fa4a
Initial commit
YashPandit4u May 27, 2024
6a930f1
rename in main module file
YashPandit4u May 27, 2024
5f3b316
Logger used for prints, error handling improved, one extra file creat…
YashPandit4u Jun 1, 2024
802e438
Merge branch 'oracle:main' into feature/multi-model-artifact-handler
YashPandit4u Jun 1, 2024
35e8464
Reformatted using black.
YashPandit4u Jun 1, 2024
e625107
Merge branch 'feature/multi-model-artifact-handler' of https://github…
YashPandit4u Jun 1, 2024
74a235e
Separate logger used.
YashPandit4u Jun 1, 2024
848972e
Added python docs for all methods.
YashPandit4u Jun 1, 2024
d35b917
Added class DataScienceModelCollection that extends from DataScienceM…
YashPandit4u Jun 5, 2024
3ac5338
removed old model description class
YashPandit4u Jun 5, 2024
822c7e8
formatted using black
YashPandit4u Jun 5, 2024
3d4d950
black formatter used and one return type added.
YashPandit4u Jun 5, 2024
d4ef023
Added add_artifact and remove_artifact method in main DataScienceMode…
YashPandit4u Jun 11, 2024
c40ea8b
Removed new added class.
YashPandit4u Jun 11, 2024
d0309f4
Added uri based approach
YashPandit4u Jun 13, 2024
19ec921
Added unit tests.
YashPandit4u Jun 13, 2024
9089028
Changed the pydocs according to ads specifications
YashPandit4u Jun 13, 2024
464cb37
Merge branch 'main' into feature/multi-model-artifact-handler
YashPandit4u Jun 13, 2024
49be8f8
replaces regex with normal splitting for uri
YashPandit4u Jun 13, 2024
a2894a5
Merge branch 'feature/multi-model-artifact-handler' of https://github…
YashPandit4u Jun 13, 2024
5d327cd
removed default_signer
YashPandit4u Jun 13, 2024
b7e7d74
Used ObjectStorageDetails.from_path(uri) for url decoding.
YashPandit4u Jun 13, 2024
4418bf8
namespace, bucket, prefix way added again.
YashPandit4u Jun 14, 2024
e2e7099
Merge branch 'main' into feature/multi-model-artifact-handler
YashPandit4u Jun 14, 2024
4cc8823
Given options for both uri and (namespace, bucket) in add_artifact, r…
YashPandit4u Jun 15, 2024
7f3cc7f
Updated python docs
YashPandit4u Jun 15, 2024
32f4826
Ran black formatter.
YashPandit4u Jun 15, 2024
9f69e35
Ran black formatter of UTs.
YashPandit4u Jun 15, 2024
ad0ce2e
Added uri UTs.
YashPandit4u Jun 15, 2024
eec2b4d
Prefix null check added.
YashPandit4u Jun 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions ads/model/datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ModelProvenanceNotFoundError,
OCIDataScienceModel,
)
from ads.common import oci_client as oc

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1466,3 +1467,183 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
bucket_uri.append(uri)

return bucket_uri[0] if len(bucket_uri) == 1 else bucket_uri, artifact_size

def add_artifact(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Could you please add an examples section for the add and delete artifacts feature?

self,
uri: str,
files: Optional[List[str]] = None,
):
"""
Adds information about objects in a specified bucket to the model description JSON.

Parameters
----------
uri (str): The URI representing the location of the artifact in OCI object storage.
files (list of str, optional): A list of file names to include in the model description.
If provided, only objects with matching file names will be included. Defaults to None.

Returns
-------
None

Raises
------
ValueError: If no files are found to add to the model description.

Note
----
- If `files` is not provided, it retrieves information about all objects in the bucket.
If `files` is provided, it only retrieves information about objects with matching file names.
- If no objects are found to add to the model description, a ValueError is raised.
"""

bucket, namespace, prefix = self._extract_oci_uri_components(uri)

if self.model_file_description == None:
self.empty_json = {
"version": "1.0",
"type": "modelOSSReferenceDescription",
"models": [],
}
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, self.empty_json)

# Get object storage client
self.object_storage_client = oc.OCIClientFactory(**(self.dsc_model.auth)).object_storage

# Remove if the model already exists
self.remove_artifact(uri=uri)

def check_if_file_exists(fileName):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: I would not recommend to create nested functions, it would be hard to write unit tests for such methods. I would rather recommend to move such methods into utils or outside the class.

isExists = False
try:
headResponse = self.object_storage_client.head_object(
namespace, bucket, object_name=fileName
)
if headResponse.status == 200:
isExists = True
except Exception as e:
if hasattr(e, "status") and e.status == 404:
logger.error(f"File not found in bucket: {fileName}")
else:
logger.error(f"An error occured: {e}")
return isExists

# Function to un-paginate the api call with while loop
def list_obj_versions_unpaginated():
objectStorageList = []
has_next_page, opc_next_page = True, None
while has_next_page:
response = self.object_storage_client.list_object_versions(
namespace_name=namespace,
bucket_name=bucket,
prefix=prefix,
fields="name,size",
page=opc_next_page,
)
objectStorageList.extend(response.data.items)
has_next_page = response.has_next_page
opc_next_page = response.next_page
return objectStorageList

# Fetch object details and put it into the objects variable
objectStorageList = []
if files == None:
objectStorageList = list_obj_versions_unpaginated()
else:
for fileName in files:
if check_if_file_exists(fileName=fileName):
objectStorageList.append(
self.object_storage_client.list_object_versions(
namespace_name=namespace,
bucket_name=bucket,
prefix=fileName,
fields="name,size",
).data.items[0]
)

objects = [
{"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
for obj in objectStorageList
if obj.size > 0
]

if len(objects) == 0:
error_message = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT:

error_message = (
        f"No files found to add in the bucket '{bucket}' within the namespace '{namespace}' "
        f"and prefix '{prefix}'. Expected file names: {files}"
    )

f"No files to add in the bucket: {bucket} with namespace: {namespace} "
f"and prefix: {prefix}. File names: {files}"
)
logger.error(error_message)
raise ValueError(error_message)

tmp_model_file_description = self.model_file_description
tmp_model_file_description["models"].append(
{
"namespace": namespace,
"bucketName": bucket,
"prefix": prefix,
"objects": objects,
}
)
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)

def remove_artifact(self, uri: str):
"""
Removes information about objects in a specified bucket from the model description JSON.

Parameters
----------
uri (str): The URI representing the location of the artifact in OCI object storage.

Returns
-------
None

Raises
------
ValueError: If the model description JSON is None.
"""

bucket, namespace, prefix = self._extract_oci_uri_components(uri)

def findModelIdx():
for idx, model in enumerate(self.model_file_description["models"]):
if (
model["namespace"],
model["bucketName"],
(model["prefix"] if ("prefix" in model) else None),
) == (namespace, bucket, prefix):
return idx
return -1

if self.model_file_description == None:
return

modelSearchIdx = findModelIdx()
if modelSearchIdx == -1:
return
else:
# model found case
self.model_file_description["models"].pop(modelSearchIdx)

def _extract_oci_uri_components(self, uri: str):
if not uri.startswith("oci://"):
raise ValueError("Invalid URI format")

# Remove the "oci://" prefix
uri = uri[len("oci://"):]

# Split by "@" to get bucket_name and the rest
bucket_and_rest = uri.split("@", 1)
if len(bucket_and_rest) != 2:
raise ValueError("Invalid URI format")

bucket_name = bucket_and_rest[0]

# Split the rest by "/" to get namespace and prefix
namespace_and_prefix = bucket_and_rest[1].split("/", 1)
if len(namespace_and_prefix) == 1:
namespace, prefix = namespace_and_prefix[0], ""
else:
namespace, prefix = namespace_and_prefix

return bucket_name, namespace, None if prefix == '' else prefix
129 changes: 128 additions & 1 deletion tests/unitary/default_setup/model/test_datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
OCIDataScienceModel,
)
from oci.data_science.models import ModelProvenance
from oci.object_storage.models.object_version_summary import ObjectVersionSummary
from ads.config import AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET


MODEL_OCID = "ocid1.datasciencemodel.oc1.iad.<unique_ocid>"

OCI_MODEL_PAYLOAD = {
Expand Down Expand Up @@ -163,6 +163,66 @@
"Content-Length": _MAX_ARTIFACT_SIZE_IN_BYTES + 100,
}

MODEL_BY_REF_JSON = {
"version": "1.0",
"type": "modelOSSReferenceDescription",
"models": [
{
"namespace": "ociodscdev",
"bucketName": "unzip-multi-model",
"prefix": "model-linear-1",
"objects": [
{
"name": "model-linear-1/linear-1.pkl",
"version": "ee260f4a-920a-4b4e-974a-c13a1032558e",
"sizeInBytes": 565
}
]
},
{
"namespace": "ociodscdev",
"bucketName": "unzip-multi-model",
"prefix": "model-linear-2",
"objects": [
{
"name": "model-linear-2/linear-2.pkl",
"version": "dc26a7d2-8041-4b37-8ed0-9e8c10869340",
"sizeInBytes": 565
}
]
},
{
"namespace": "ociodscdev",
"bucketName": "unzip-multi-model",
"prefix": "model-linear-3",
"objects": [
{
"name": "model-linear-3/linear-3.pkl",
"version": "a22c1211-f7d4-4fd4-96d8-4e3a048c5cf7",
"sizeInBytes": 565
}
]
},
{
"namespace": "ociodscdev",
"bucketName": "unzip-multi-model",
"prefix": "",
"objects": [
{
"name": "runtime.yaml",
"version": "30afb1a6-ab1f-42a3-95e3-09f61a0046fd",
"sizeInBytes": 334
},
{
"name": "score.py",
"version": "c4ccaf96-05be-4174-ac3b-15dce2f558fe",
"sizeInBytes": 772
}
]
}
]
}
CONST_MODEL_FILE_DESCRIPTION = "modelDescription"

class TestDataScienceModel:
DEFAULT_PROPERTIES_PAYLOAD = {
Expand Down Expand Up @@ -1070,3 +1130,70 @@ def test_download_artifact_for_model_created_by_reference(
)

mock_large_download.assert_called()



@patch("ads.common.oci_client.OCIClientFactory")
def test_add_artifact(self, mock_oci_client_factory):
r = ObjectVersionSummary()
r.name = "model-linear-2/linear-2.pkl"
r.size = 566
r.time_modified = "2024-04-22T12:34:26.670000+00:00"
r.version_id = "dc26a7d2-8041-4b37-8ed0-9e8c10869340"
resp = [r]

# Mock response object
mock_response = MagicMock()
mock_response.data.items = resp
mock_response.has_next_page = False
mock_response.next_page = None

# Mock object storage client
mock_object_storage_client = MagicMock()
mock_object_storage_client.list_object_versions.return_value = mock_response

mock_oci_client_factory.return_value.object_storage = mock_object_storage_client

# self.mock_dsc_model
self.mock_dsc_model.add_artifact(uri="oci://bucket@namespace/prefix")
expected_out = {
'version': '1.0',
'type': 'modelOSSReferenceDescription',
'models': [
{
'namespace': 'namespace',
'bucketName': 'bucket',
'prefix': 'prefix',
'objects': [
{
'name': 'model-linear-2/linear-2.pkl',
'version': 'dc26a7d2-8041-4b37-8ed0-9e8c10869340',
'sizeInBytes': 566
}
]
}
]
}
assert self.mock_dsc_model.model_file_description == expected_out
self.mock_dsc_model.remove_artifact(uri="oci://bucket@namespace/prefix")
assert self.mock_dsc_model.model_file_description != expected_out
expected_out = {
'version': '1.0',
'type': 'modelOSSReferenceDescription',
'models': []
}
assert self.mock_dsc_model.model_file_description == expected_out

def test_remove_artifact(self):
self.mock_dsc_model.remove_artifact(uri="oci://unzip-multi-model@ociodscdev/model-linear-1")
assert self.mock_dsc_model.model_file_description == None

self.mock_dsc_model.set_spec(CONST_MODEL_FILE_DESCRIPTION, deepcopy(MODEL_BY_REF_JSON))
assert self.mock_dsc_model.model_file_description == MODEL_BY_REF_JSON

self.mock_dsc_model.remove_artifact(uri="oci://unzip-multi-model@ociodscdev/model-linear-1")
assert self.mock_dsc_model.model_file_description != MODEL_BY_REF_JSON

exptected_json = deepcopy(MODEL_BY_REF_JSON)
exptected_json["models"] = exptected_json["models"][1:]
assert self.mock_dsc_model.model_file_description == exptected_json
Loading