Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket changes for AQUA (1.0.3) #892

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
repos:
# ruff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ADS has switched from black to ruff

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes , some of the files have outdated changes. will resolve these.

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.9
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
files: ^ads
exclude: ^docs/
- id: ruff-format
types_or: [ python, pyi, jupyter ]
exclude: ^docs/
# Standard hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand All @@ -32,6 +20,12 @@ repos:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
exclude: ^docs/
# Black, the code formatter, natively supports pre-commit
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
exclude: ^docs/
# Regex based rst files common mistakes detector
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
Expand Down
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ clean:
@find ./ -name 'Thumbs.db' -exec rm -f {} \;
@find ./ -name '*~' -exec rm -f {} \;
@find ./ -name '.DS_Store' -exec rm -f {} \;
test:
pip install -e .
jupyter server extension enable --py ads.aqua.extension
jupyter lab --NotebookApp.disable_check_xsrf=True --no-browser
install:
pip install -e .
jupyter server extension enable --py ads.aqua.extension
568 changes: 567 additions & 1 deletion ads/aqua/dummy_data/oci_models.json

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions ads/aqua/extension/deployment_ws_msg_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List, Union

from ads import logger
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
from ads.aqua.extension.models.ws_models import RequestResponseType, ListDeploymentResponse, ListDeploymentRequest
from ads.aqua.modeldeployment import AquaDeploymentApp
from ads.config import COMPARTMENT_OCID


class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):

def __init__(self, message: Union[str, bytes]):
super().__init__(message)

@staticmethod
def get_message_types() -> List[RequestResponseType]:
return [RequestResponseType.ListDeployments]

@handle_exceptions
def process(self) -> ListDeploymentResponse:
list_deployment_request = ListDeploymentRequest.from_json(self.message)
deployment_list = AquaDeploymentApp().list(
compartment_id=list_deployment_request.compartment_id or COMPARTMENT_OCID,
project_id=list_deployment_request.project_id,
)
response = ListDeploymentResponse(
message_id=list_deployment_request.message_id,
kind=RequestResponseType.ListDeployments,
data=deployment_list,
)
return response
17 changes: 17 additions & 0 deletions ads/aqua/extension/models/ws_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@

from ads.aqua.evaluation.entities import AquaEvaluationSummary
from ads.aqua.model.entities import AquaModelSummary
from ads.aqua.modeldeployment.entities import AquaDeployment
from ads.common.extended_enum import ExtendedEnumMeta
from ads.common.serializer import DataClassSerializable


class RequestResponseType(str, metaclass=ExtendedEnumMeta):
ListEvaluations = "ListEvaluations"
ListDeployments = "ListDeployments"
ListModels = "ListModels"
Error = "Error"

Expand Down Expand Up @@ -43,13 +45,28 @@ class ListEvaluationsRequest(BaseRequest):
@dataclass
class ListModelsRequest(BaseRequest):
compartment_id: Optional[str] = None
project_id: Optional[str] = None
model_type: Optional[str] = None
kind = RequestResponseType.ListDeployments


@dataclass
class ListEvaluationsResponse(BaseResponse):
data: List[AquaEvaluationSummary]


@dataclass
class ListDeploymentRequest(BaseRequest):
compartment_id: str
project_id: Optional[str] = None
kind = RequestResponseType.ListDeployments


@dataclass
class ListDeploymentResponse(BaseResponse):
data: List[AquaDeployment]


@dataclass
class ListModelsResponse(BaseResponse):
data: List[AquaModelSummary]
Expand Down
33 changes: 33 additions & 0 deletions ads/aqua/extension/models_ws_msg_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import List, Union

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
from ads.aqua.extension.models.ws_models import RequestResponseType,ListModelsResponse, ListModelsRequest
from ads.aqua.model import AquaModelApp
from ads.config import COMPARTMENT_OCID


class AquaModelWSMsgHandler(AquaWSMsgHandler):

def __init__(self, message: Union[str, bytes]):
super().__init__(message)

@staticmethod
def get_message_types() -> List[RequestResponseType]:
return [RequestResponseType.ListModels]

@handle_exceptions
def process(self) -> ListModelsResponse:
list_models_request = ListModelsRequest.from_json(self.message)
print(list_models_request)
models_list = AquaModelApp().list(
compartment_id=list_models_request.compartment_id or COMPARTMENT_OCID,
project_id=list_models_request.project_id,
model_type=list_models_request.model_type
)
response = ListModelsResponse(
message_id=list_models_request.message_id,
kind=RequestResponseType.ListModels,
data=models_list,
)
return response
4 changes: 3 additions & 1 deletion ads/aqua/extension/ui_websocket_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ads.aqua import logger
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
from ads.aqua.extension.deployment_ws_msg_handler import AquaDeploymentWSMsgHandler
from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler
from ads.aqua.extension.models.ws_models import (
AquaWsError,
Expand All @@ -22,6 +23,7 @@
ErrorResponse,
RequestResponseType,
)
from ads.aqua.extension.models_ws_msg_handler import AquaModelWSMsgHandler

MAX_WORKERS = 20

Expand All @@ -43,7 +45,7 @@ def get_aqua_internal_error_response(message_id: str) -> ErrorResponse:
class AquaUIWebSocketHandler(WebSocketHandler):
"""Handler for Aqua Websocket."""

_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler]
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler,AquaDeploymentWSMsgHandler,AquaModelWSMsgHandler]

thread_pool: ThreadPoolExecutor

Expand Down
1 change: 1 addition & 0 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
List[AquaDeployment]:
The list of the Aqua model deployments.
"""
print("kwargs: ",kwargs)
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)

model_deployments = self.list_resource(
Expand Down
20 changes: 9 additions & 11 deletions ads/common/oci_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,14 +614,13 @@ def tail(
A list of log records.
Each log record is a dictionary with the following keys: `id`, `time`, `message`.
"""
tail_logs = self._search_and_format(
return self._search_and_format(
source=source,
limit=limit,
sort_order=SortOrder.DESC,
sort_order=SortOrder.ASC,
time_start=time_start,
log_filter=log_filter,
)
return sorted(tail_logs, key=lambda log: log["time"])

def head(
self,
Expand Down Expand Up @@ -855,15 +854,14 @@ def tail(
Expression for filtering the logs. This will be the WHERE clause of the query.
Defaults to None.
"""
tail_logs = self._search_and_format(
source=source,
limit=limit,
sort_order=SortOrder.DESC,
time_start=time_start,
log_filter=log_filter,
)
self._print(
sorted(tail_logs, key=lambda log: log["time"])
self._search_and_format(
source=source,
limit=limit,
sort_order=SortOrder.ASC,
time_start=time_start,
log_filter=log_filter,
)
)

def head(
Expand Down
6 changes: 0 additions & 6 deletions ads/llm/langchain/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from langchain.llms.base import LLM
from langchain.pydantic_v1 import BaseModel, Field, root_validator

from ads import logger
from ads.common.auth import default_signer
from ads.config import COMPARTMENT_OCID

Expand Down Expand Up @@ -96,11 +95,6 @@ def validate_environment( # pylint: disable=no-self-argument
"""Validate that python package exists in environment."""
# Initialize client only if user does not pass in client.
# Users may choose to initialize the OCI client by themselves and pass it into this model.
logger.warning(
f"The ads langchain plugin {cls.__name__} will be deprecated soon. "
"Please refer to https://python.langchain.com/v0.2/docs/integrations/providers/oci/ "
"for the latest support."
)
if not values.get("client"):
auth = values.get("auth", {})
client_kwargs = auth.get("client_kwargs") or {}
Expand Down
Loading
Loading