Skip to content

Commit

Permalink
Added add_artifact and remove_artifact method in main DataScienceMode…
Browse files Browse the repository at this point in the history
…l class itself.
  • Loading branch information
YashPandit4u committed Jun 11, 2024
1 parent 3d4d950 commit d4ef023
Showing 1 changed file with 158 additions and 0 deletions.
158 changes: 158 additions & 0 deletions ads/model/datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
ModelProvenanceNotFoundError,
OCIDataScienceModel,
)
from ads.common import oci_client as oc
from ads.common.auth import default_signer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1466,3 +1468,159 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
bucket_uri.append(uri)

return bucket_uri[0] if len(bucket_uri) == 1 else bucket_uri, artifact_size

def add_artifact(
self,
namespace: str,
bucket: str,
prefix: Optional[str] = None,
files: Optional[List[str]] = None,
):
"""
Adds information about objects in a specified bucket to the model description JSON.
Parameters:
- namespace (str): The namespace of the object storage.
- bucket (str): The name of the bucket containing the objects.
- prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
- files (list of str, optional): A list of file names to include in the model description.
If provided, only objects with matching file names will be included. Defaults to None.
Returns:
- None
Raises:
- ValueError: If no files are found to add to the model description.
Note:
- If `files` is not provided, it retrieves information about all objects in the bucket.
If `files` is provided, it only retrieves information about objects with matching file names.
- If no objects are found to add to the model description, a ValueError is raised.
"""
if self.model_file_description == None:
self.empty_json = {
"version": "1.0",
"type": "modelOSSReferenceDescription",
"models": [],
}
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, self.empty_json)

# Get object storage client
authData = default_signer()
self.object_storage_client = oc.OCIClientFactory(**authData).object_storage

# Remove if the model already exists
self.remove_artifact(namespace, bucket, prefix)

def check_if_file_exists(fileName):
isExists = False
try:
headResponse = self.object_storage_client.head_object(
namespace, bucket, object_name=fileName
)
if headResponse.status == 200:
isExists = True
except Exception as e:
if hasattr(e, "status") and e.status == 404:
logger.error(f"File not found in bucket: {fileName}")
else:
logger.error(f"An error occured: {e}")
return isExists

# Function to un-paginate the api call with while loop
def list_obj_versions_unpaginated():
objectStorageList = []
has_next_page, opc_next_page = True, None
while has_next_page:
response = self.object_storage_client.list_object_versions(
namespace_name=namespace,
bucket_name=bucket,
prefix=prefix,
fields="name,size",
page=opc_next_page,
)
objectStorageList.extend(response.data.items)
has_next_page = response.has_next_page
opc_next_page = response.next_page
return objectStorageList

# Fetch object details and put it into the objects variable
objectStorageList = []
if files == None:
objectStorageList = list_obj_versions_unpaginated()
else:
for fileName in files:
if check_if_file_exists(fileName=fileName):
objectStorageList.append(
self.object_storage_client.list_object_versions(
namespace_name=namespace,
bucket_name=bucket,
prefix=fileName,
fields="name,size",
).data.items[0]
)

objects = [
{"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
for obj in objectStorageList
if obj.size > 0
]

if len(objects) == 0:
error_message = (
f"No files to add in the bucket: {bucket} with namespace: {namespace} "
f"and prefix: {prefix}. File names: {files}"
)
logger.error(error_message)
raise ValueError(error_message)

tmp_model_file_description = self.model_file_description
tmp_model_file_description["models"].append(
{
"namespace": namespace,
"bucketName": bucket,
"prefix": prefix,
"objects": objects,
}
)
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)

def remove_artifact(self, namespace: str, bucket: str, prefix: Optional[str] = None):
"""
Removes information about objects in a specified bucket from the model description JSON.
Parameters:
- namespace (str): The namespace of the object storage.
- bucket (str): The name of the bucket containing the objects.
- prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
Returns:
- None
Note:
- This method removes information about objects in the specified bucket from the
instance of the ModelDescription.
- If a matching model (with the specified namespace, bucket name, and prefix) is found
in the model description JSON, it is removed.
- If no matching model is found, the method returns without making any changes.
"""

def findModelIdx():
for idx, model in enumerate(self.model_file_description["models"]):
if (
model["namespace"],
model["bucketName"],
(model["prefix"] if ("prefix" in model) else None),
) == (namespace, bucket, prefix):
return idx
return -1

if self.model_file_description == None:
return

modelSearchIdx = findModelIdx()
if modelSearchIdx == -1:
return
else:
# model found case
self.model_file_description["models"].pop(modelSearchIdx)

0 comments on commit d4ef023

Please sign in to comment.