Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
99c4ae4
add hf apis
VipulMascarenhas Jul 25, 2024
043cf4c
adding artifact location in get model response
kumar-shivam-ranjan Jul 25, 2024
1d10623
[HF integration] GET model changes (#914)
VipulMascarenhas Jul 25, 2024
74349a8
review comments
VipulMascarenhas Jul 25, 2024
a57f7ab
Merge branch 'ODSC-60405/hf-integration' of github.com:oracle/acceler…
VipulMascarenhas Jul 25, 2024
2def6de
Add HF integration APIs (#913)
VipulMascarenhas Jul 25, 2024
7e5f400
os model registration validation
VipulMascarenhas Jul 26, 2024
1af55a3
update validation check
VipulMascarenhas Jul 26, 2024
656e6e1
update validation to use file formats
VipulMascarenhas Jul 29, 2024
95f2959
Added logic to download model from hugginface.
lu-ohai Jul 29, 2024
e0d9183
Merge branch 'ODSC-60499/register-os-model-validation' of https://git…
lu-ohai Jul 29, 2024
b939914
Updated pr.
lu-ohai Jul 29, 2024
28c8082
Updated pr.
lu-ohai Jul 29, 2024
31575c1
Added Logic to Download Model from HugginFace (#917)
lu-ohai Jul 29, 2024
87ee74f
support model_file during register and deploy
VipulMascarenhas Jul 31, 2024
11d8880
add hf files api
VipulMascarenhas Jul 31, 2024
eea53c2
[ODSC-60499] Update validation for models registered from object stor…
VipulMascarenhas Jul 31, 2024
0f37bc8
update tags and add tests
VipulMascarenhas Aug 2, 2024
8981dd3
remove hf token auth
VipulMascarenhas Aug 2, 2024
1f7e9c5
add license to tags
VipulMascarenhas Aug 2, 2024
8069c6f
Fixed downloading huggingface model issue
lu-ohai Aug 7, 2024
3740ed6
Fixed downloading huggingface model issue (#925)
lu-ohai Aug 7, 2024
2cdb5e3
Fixed fine tuning issue while both types of model files are present.
lu-ohai Aug 8, 2024
229fba0
Fixed fine tuning issue while both types of model files are present. …
lu-ohai Aug 8, 2024
0cfc2a4
Update evaluation params logic
harsh97 Aug 9, 2024
f4a906e
deployment changes
harsh97 Aug 9, 2024
f0e2e22
Update evaluation params logic (#927)
harsh97 Aug 9, 2024
5215a9d
Update release_notes.rst
mayoor Aug 9, 2024
b74fd3b
Update pyproject.toml
mayoor Aug 9, 2024
ed2c854
Merge branch 'main' into ODSC-60499/additional_tests
mayoor Aug 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

Expand Down Expand Up @@ -175,7 +174,7 @@ def create_model_version_set(
f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
)

except:
except Exception:
logger.debug(
f"Model version set {model_version_set_name} doesn't exist. "
"Creating new model version set."
Expand Down Expand Up @@ -254,7 +253,7 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:

try:
response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
return True if response.status == 200 else False
return response.status == 200
except oci.exceptions.ServiceError as ex:
if ex.status == 404:
logger.info(f"Artifact not found in model {model_id}.")
Expand Down Expand Up @@ -302,15 +301,15 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
config_path,
config_file_name=config_file_name,
)
except:
except Exception:
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
try:
config_path = f"{artifact_path.rstrip('/')}/config/"
config = load_config(
config_path,
config_file_name=config_file_name,
)
except:
except Exception:
pass

if not config:
Expand Down Expand Up @@ -343,7 +342,7 @@ def build_cli(self) -> str:
params = [
f"--{field.name} {getattr(self,field.name)}"
for field in fields(self.__class__)
if getattr(self, field.name)
if getattr(self, field.name) is not None
]
cmd = f"{cmd} {' '.join(params)}"
return cmd
9 changes: 9 additions & 0 deletions ads/aqua/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
MODEL_FORMAT = "model_format"
MODEL_ARTIFACT_FILE = "model_file"


class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
Expand All @@ -59,6 +60,14 @@ class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"


# Container family identifier(s) for evaluation containers.
# NOTE(review): presumably matched against a container "family" field in the
# service container index — confirm against callers.
class EvaluationContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
    # Family name string for the evaluation container.
    AQUA_EVALUATION_CONTAINER_FAMILY = "odsc-llm-evaluate"


# Container family identifier(s) for fine-tuning containers.
# NOTE(review): presumably matched against a container "family" field in the
# service container index — confirm against callers.
class FineTuningContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
    # Family name string for the fine-tuning container.
    AQUA_FINETUNING_CONTAINER_FAMILY = "odsc-llm-fine-tuning"


class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
TEXT_GENERATION_INFERENCE = "text-generation-inference"

Expand Down
129 changes: 128 additions & 1 deletion ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import os
import random
import re
import shlex
import subprocess
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
Expand All @@ -19,6 +21,13 @@
import fsspec
import oci
from cachetools import TTLCache, cached
from huggingface_hub.hf_api import HfApi, ModelInfo
from huggingface_hub.utils import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from oci.data_science.models import JobRun, Model
from oci.object_storage.models import ObjectSummary

Expand All @@ -37,6 +46,7 @@
COMPARTMENT_MAPPING_KEY,
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
CONTAINER_INDEX,
HF_LOGIN_DEFAULT_TIMEOUT,
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
MODEL_BY_REFERENCE_OSS_PATH_KEY,
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
Expand All @@ -47,7 +57,7 @@
VLLM_INFERENCE_RESTRICTED_PARAMS,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.common.auth import default_signer
from ads.common.auth import AuthState, default_signer
from ads.common.extended_enum import ExtendedEnumMeta
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
Expand Down Expand Up @@ -771,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""


def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
    """Upload a local folder to object storage via the OCI CLI bulk upload.

    Args:
        os_path (str): object storage URI with prefix. This is the path to upload
        local_dir (str): Local directory where the object is downloaded
        model_name (str): Name of the huggingface model
    Returns:
        str: Object storage URI of the uploaded folder, in the form
            ``oci://<bucket>@<namespace>/<prefix>/<model_name>/``
    Raises:
        ValueError: If versioning is not enabled on the target bucket.
    """
    os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
    if not os_details.is_bucket_versioned():
        raise ValueError(f"Version is not enabled at object storage location {os_path}")
    auth_state = AuthState()
    object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
    # Build the argv list directly instead of formatting one string and
    # re-splitting it: a local directory or prefix containing spaces or quotes
    # would otherwise be broken into several arguments by shlex.split.
    command = [
        "oci",
        "os",
        "object",
        "bulk-upload",
        "--src-dir",
        str(local_dir),
        "--prefix",
        object_path,
        "-bn",
        str(os_details.bucket),
        "-ns",
        str(os_details.namespace),
        "--auth",
        str(auth_state.oci_iam_type),
        "--profile",
        str(auth_state.oci_key_profile),
        "--no-overwrite",
    ]
    try:
        # shlex.join renders a copy-pasteable command line for the log.
        logger.info(f"Running: {shlex.join(command)}")
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        # Failure is logged but not propagated (best-effort, as before).
        # check_call does not capture output, so the exception carries only the
        # exit code; the CLI's own messages go to the process's stdout/stderr.
        logger.error(f"Error uploading the object. Exit code: {e.returncode}")

    return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path


def is_service_managed_container(container):
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)

Expand Down Expand Up @@ -935,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
return TGI_INFERENCE_RESTRICTED_PARAMS
else:
return set()


def get_huggingface_login_timeout() -> int:
    """Return the Hugging Face login timeout.

    The value is read from the ``HF_LOGIN_DEFAULT_TIMEOUT`` environment
    variable. When the variable is unset, or its value cannot be parsed as an
    integer, the module default ``HF_LOGIN_DEFAULT_TIMEOUT`` is returned.

    Returns
    -------
    timeout: int
        huggingface login timeout.

    """
    configured = os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
    try:
        return int(configured)
    except ValueError:
        # Non-numeric override: silently fall back to the default.
        return HF_LOGIN_DEFAULT_TIMEOUT


def format_hf_custom_error_message(error: HfHubHTTPError):
    """
    Translates a Hugging Face Hub error into a user-friendly AquaRuntimeError.

    This function never returns: every code path raises.

    Parameters
    ----------
    error (HfHubHTTPError): The caught exception.

    Raises
    ------
    AquaRuntimeError: A user-friendly error message.
    """
    # Extract the repository URL from the error message if present
    match = re.search(r"(https://huggingface\.co/[^\s]+)", str(error))
    url = match.group(1) if match else "the requested Hugging Face URL."

    # GatedRepoError must be tested before RepositoryNotFoundError: in
    # huggingface_hub it is a subclass of RepositoryNotFoundError, so checking
    # the parent first would make this branch unreachable.
    if isinstance(error, GatedRepoError):
        raise AquaRuntimeError(
            reason=f"Access denied to `{url}` "
            "This repository is gated. Access is restricted to authorized users. "
            "Please request access or check with the repository administrator. "
            "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "GatedRepoError"},
        )

    if isinstance(error, RepositoryNotFoundError):
        raise AquaRuntimeError(
            reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
            "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "RepositoryNotFoundError"},
        )

    if isinstance(error, RevisionNotFoundError):
        raise AquaRuntimeError(
            reason=f"The specified revision could not be found at `{url}` "
            "Please check the revision identifier and try again.",
            service_payload={"error": "RevisionNotFoundError"},
        )

    # Any other Hub HTTP error: generic guidance.
    raise AquaRuntimeError(
        reason=f"An error occurred while accessing `{url}` "
        "Please check your network connection and try again. "
        "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
        "To register your token, run this command in your terminal: `huggingface-cli login`",
        service_payload={"error": "Error"},
    )


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def get_hf_model_info(repo_id: str) -> ModelInfo:
    """Fetch the Hugging Face Hub metadata for a model repository.

    The result is memoized by the TTL cache configured in the decorator
    (a single entry, five-hour lifetime). For models that require a token,
    token validation is assumed to have been done already.

    Parameters
    ----------
    repo_id: str
        hugging face model repository name

    Returns
    -------
    instance of ModelInfo object

    """
    hub_client = HfApi()
    try:
        model_details = hub_client.model_info(repo_id=repo_id)
    except HfHubHTTPError as hub_error:
        # The formatter raises the user-facing error itself.
        raise format_hf_custom_error_message(hub_error) from hub_error
    return model_details
1 change: 1 addition & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
AQUA_MODEL_ARTIFACT_FILE = "model_file"
HF_LOGIN_DEFAULT_TIMEOUT = 2

TRAINING_METRICS_FINAL = "training_metrics_final"
VALIDATION_METRICS_FINAL = "validation_metrics_final"
Expand Down
2 changes: 1 addition & 1 deletion ads/aqua/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def create(
enable_spec=True
).inference
for container in inference_config.values():
if container.name == runtime.image.split(":")[0]:
if container.name == runtime.image[:runtime.image.rfind(":")]:
eval_inference_configuration = (
container.spec.evaluation_configuration
)
Expand Down
80 changes: 75 additions & 5 deletions ads/aqua/extension/common_handler.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


from importlib import metadata

import huggingface_hub
import requests
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
from ads.aqua.common.utils import fetch_service_compartment, known_realm
from ads.aqua.common.utils import (
fetch_service_compartment,
get_huggingface_login_timeout,
known_realm,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors

Expand Down Expand Up @@ -46,16 +52,80 @@ def get(self):

"""
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
return self.finish(dict(status="ok"))
return self.finish({"status": "ok"})
elif known_realm():
return self.finish(dict(status="compatible"))
return self.finish({"status": "compatible"})
else:
raise AquaResourceAccessError(
f"The AI Quick actions extension is not compatible in the given region."
"The AI Quick actions extension is not compatible in the given region."
)


class NetworkStatusHandler(AquaAPIhandler):
    """Handler to check internet connection."""

    @handle_exceptions
    def get(self):
        """Probes the Hugging Face Hub and reports success if it is reachable.

        Any response from the host counts as success; the HTTP status of the
        probe is not inspected. Connection errors and timeouts propagate to the
        `handle_exceptions` decorator.
        """
        # Probe the canonical Hub host (huggingface.co) that the other Hugging
        # Face code paths use, rather than the huggingface.com alias. This
        # avoids an extra redirect within the short timeout and tests the host
        # that must actually be reachable.
        requests.get("https://huggingface.co", timeout=get_huggingface_login_timeout())
        return self.finish({"status": 200, "message": "success"})


class HFLoginHandler(AquaAPIhandler):
    """Handler to login to HF."""

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Logs in to Hugging Face with the token supplied in the JSON body.

        The request body must be a JSON object containing a non-empty
        ``token`` entry.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid.
        AquaRuntimeError
            If the Hugging Face login call fails for any reason.
        """
        try:
            payload = self.get_json_body()
        except Exception as parse_error:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from parse_error

        if not payload:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        hf_token = payload.get("token")
        if not hf_token:
            missing_token_message = Errors.MISSING_REQUIRED_PARAMETER.format("token")
            raise HTTPError(400, missing_token_message)

        # Login to HF; any failure is surfaced with the original exception type name.
        try:
            huggingface_hub.login(token=hf_token, new_session=False)
        except Exception as login_error:
            error_payload = {"error": type(login_error).__name__}
            raise AquaRuntimeError(
                reason=str(login_error), service_payload=error_payload
            ) from login_error

        result = {"status": 200, "message": "login successful"}
        return self.finish(result)


class HFUserStatusHandler(AquaAPIhandler):
    """Handler to check if user logged in to the HF."""

    @handle_exceptions
    def get(self):
        """Reports whether a Hugging Face token is registered locally.

        Raises
        ------
        AquaRuntimeError
            If `HfApi().whoami()` raises LocalTokenNotFoundError. Other
            exceptions from the call propagate to `handle_exceptions`.
        """
        try:
            HfApi().whoami()
        except LocalTokenNotFoundError as err:
            # The literals below are concatenated; the trailing space after
            # "command." keeps the two sentences from running together.
            raise AquaRuntimeError(
                "You are not logged in. Please log in to Hugging Face using the `huggingface-cli login` command. "
                "See https://huggingface.co/settings/tokens.",
            ) from err

        return self.finish({"status": 200, "message": "logged in"})


# (route name, handler class) pairs exposed by this module.
# NOTE(review): presumably collected by the extension's route registration
# code to build the URL table — confirm where `__handlers__` is consumed.
__handlers__ = [
    ("ads_version", ADSVersionHandler),
    ("hello", CompatibilityCheckHandler),
    ("network_status", NetworkStatusHandler),
    ("hf_login", HFLoginHandler),
    ("hf_logged_in", HFUserStatusHandler),
]
2 changes: 2 additions & 0 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def post(self, *args, **kwargs):
container_family = input_data.get("container_family")
ocpus = input_data.get("ocpus")
memory_in_gbs = input_data.get("memory_in_gbs")
model_file = input_data.get("model_file")

self.finish(
AquaDeploymentApp().create(
Expand All @@ -122,6 +123,7 @@ def post(self, *args, **kwargs):
container_family=container_family,
ocpus=ocpus,
memory_in_gbs=memory_in_gbs,
model_file=model_file,
)
)

Expand Down
Loading
Loading