From 528f76e1d071fa576d8654b624b63b60c5964252 Mon Sep 17 00:00:00 2001
From: Aaron <29749331+aarnphm@users.noreply.github.com>
Date: Thu, 15 Jun 2023 01:58:49 -0400
Subject: [PATCH] fix(client): use httpx for running calls within an async context

This is so that client.query works within an async context.

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 pyproject.toml                      |  3 +++
 src/openllm/_configuration.py       |  4 ++--
 src/openllm/utils/dantic.py         |  4 ++--
 src/openllm_client/runtimes/base.py | 27 +++++++++++++++++++++++++--
 src/openllm_client/runtimes/grpc.py |  5 ++++-
 5 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0e1232542..5df1b4b31 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,9 @@ dependencies = [
     # tabulate for CLI with CJK support
     # >=0.9.0 for some bug fixes
     "tabulate[widechars]>=0.9.0",
+    # httpx used within openllm.client
+    "httpx",
+    # for typing support
     "typing_extensions",
 ]
 description = 'OpenLLM: REST/gRPC API server for running any open Large-Language Model - StableLM, Llama, Alpaca, Dolly, Flan-T5, Custom'
diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py
index a51f62426..6b5dfbbc4 100644
--- a/src/openllm/_configuration.py
+++ b/src/openllm/_configuration.py
@@ -1026,11 +1026,11 @@ def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]:
         return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove}
 
     @t.overload
-    def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny:
+    def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
         ...
 
     @t.overload
-    def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
+    def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny:
         ...
 
     def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny:
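Note: the overload swap above is a correctness fix, not a cosmetic one. The overload carrying the `= ...` default should be the one whose literal matches the implementation default of `return_as_dict=False`, so a bare call is inferred as `transformers.GenerationConfig`. A minimal sketch of the resolved types (the `AutoConfig.for_model` accessor and the model name are illustrative assumptions, not part of this diff):

    import openllm

    config = openllm.AutoConfig.for_model("flan-t5")
    # No argument: matches the Literal[False] overload -> transformers.GenerationConfig
    generation_config = config.to_generation_config()
    # Explicit True: matches the Literal[True] overload -> DictStrAny (plain dict)
    generation_dict = config.to_generation_config(return_as_dict=True)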
diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py
index 74f0bff0a..ac598bb82 100644
--- a/src/openllm/utils/dantic.py
+++ b/src/openllm/utils/dantic.py
@@ -280,7 +280,7 @@ def __init__(self, enum: Enum, case_sensitive: bool = False):
         self.internal_type = enum
         super().__init__([e.name for e in self.mapping], case_sensitive)
 
-    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
+    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
         if isinstance(value, self.internal_type):
             return value
         result = super().convert(value, param, ctx)
@@ -292,7 +292,7 @@ class LiteralChoice(EnumChoice):
     name = "literal"
 
-    def __init__(self, enum: t.Literal, case_sensitive: bool = False):
+    def __init__(self, enum: t.LiteralString, case_sensitive: bool = False):
         # expect every literal value to belong to the same primitive type
         values = list(enum.__args__)
         item_type = type(values[0])
diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py
index be98df297..fe019daca 100644
--- a/src/openllm_client/runtimes/base.py
+++ b/src/openllm_client/runtimes/base.py
@@ -14,10 +14,13 @@
 from __future__ import annotations
 
+import asyncio
 import typing as t
 from abc import abstractmethod
+from urllib.parse import urljoin
 
 import bentoml
+import httpx
 
 import openllm
 
@@ -43,6 +46,14 @@ def metadata_v1(self) -> dict[str, t.Any]:
         ...
 
 
+def in_async_context() -> bool:
+    try:
+        _ = asyncio.get_running_loop()
+        return True
+    except RuntimeError:
+        return False
+
+
 class ClientMixin:
     _api_version: str
     _client_class: type[bentoml.client.Client]
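Note: `in_async_context` works because `asyncio.get_running_loop()` raises `RuntimeError` when called outside a running event loop; unlike the older `asyncio.get_event_loop()`, it never creates a loop as a side effect. A self-contained sketch of the behavior:

    import asyncio

    def in_async_context() -> bool:
        try:
            asyncio.get_running_loop()
            return True
        except RuntimeError:
            return False

    print(in_async_context())  # False: plain synchronous code, no loop running

    async def main() -> None:
        print(in_async_context())  # True: executing inside the loop started by asyncio.run

    asyncio.run(main())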
@@ -57,12 +68,17 @@ def __init__(self, address: str, timeout: int = 30):
         self._address = address
         self._timeout = timeout
         assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
-        self._metadata = self.call("metadata")
 
     def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
         cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
         cls._api_version = api_version
 
+    @property
+    def _metadata(self) -> dict[str, t.Any]:
+        if in_async_context():
+            return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
+        return self.call("metadata")
+
     @property
     @abstractmethod
     def model_name(self) -> str:
@@ -140,7 +156,14 @@ def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
     def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
         return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs)
         inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
-        result = self.call("generate", inputs)
+        if in_async_context():
+            result = httpx.post(
+                urljoin(self._address, f"/{self._api_version}/generate"),
+                json=openllm.utils.bentoml_cattr.unstructure(inputs),
+                timeout=self.timeout,
+            ).json()
+        else:
+            result = self.call("generate", inputs)
         r = self.postprocess(result)
 
         if return_raw_response:
diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py
index dfea9356b..c95347a1b 100644
--- a/src/openllm_client/runtimes/grpc.py
+++ b/src/openllm_client/runtimes/grpc.py
@@ -73,7 +73,10 @@ def configuration(self) -> dict[str, t.Any]:
         except KeyError:
             raise RuntimeError("Malformed service endpoint. (Possible malicious)")
 
-    def postprocess(self, result: Response) -> openllm.GenerationOutput:
+    def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
+        if isinstance(result, dict):
+            return openllm.GenerationOutput(**result)
+
         from google.protobuf.json_format import MessageToDict
 
         return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
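Note: on the async path, `query` bypasses the bentoml client entirely: the `GenerationInput` is unstructured into plain JSON via cattrs and POSTed with httpx's blocking API, which is safe to call while another event loop is running because it does not touch asyncio at all. Roughly what goes over the wire (the address, prompt, and timeout are assumed values for illustration; the endpoint path comes from the diff):

    import httpx
    from urllib.parse import urljoin

    address, api_version, timeout = "http://localhost:3000", "v1", 30  # assumed values
    payload = {"prompt": "hello", "llm_config": {}}  # shape of the unstructured GenerationInput
    result = httpx.post(urljoin(address, f"/{api_version}/generate"), json=payload, timeout=timeout).json()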
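Taken together, these changes let the synchronous client be driven from inside an already-running event loop (a Jupyter cell, a FastAPI handler), where bentoml's sync `call` previously could not re-enter the loop. A usage sketch, assuming a server on localhost:3000 and the `openllm.client.HTTPClient` entry point (both assumptions, not part of this diff):

    import asyncio
    import openllm

    client = openllm.client.HTTPClient("http://localhost:3000")

    async def handler() -> str:
        # in_async_context() is True here, so query() takes the httpx path
        # instead of bentoml's sync client.
        return client.query("Explain asyncio in one sentence.")

    print(asyncio.run(handler()))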