fix(client): using httpx for running calls within async context
This is so that client.query works within an async context

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Jun 15, 2023
1 parent b3d924e commit 528f76e
Showing 5 changed files with 36 additions and 7 deletions.
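
For context on the diffs below: the synchronous client path cannot drive its own event loop when one is already running, so the client now detects that situation and falls back to plain httpx requests. A minimal sketch of the scenario this commit enables, assuming a server at http://localhost:3000 and openllm.client.HTTPClient as the client class (both are illustrative, not taken from this diff):

import asyncio

import openllm


async def handler():
    # An event loop is already running inside this coroutine, so a synchronous
    # client that spins up its own loop would fail here. With this commit,
    # query() detects the running loop and issues a plain httpx POST against
    # the /v1/generate endpoint instead.
    client = openllm.client.HTTPClient("http://localhost:3000")
    return client.query("What is the weather in SF?")


if __name__ == "__main__":
    print(asyncio.run(handler()))
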
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -39,6 +39,9 @@ dependencies = [
# tabulate for CLI with CJK support
# >=0.9.0 for some bug fixes
"tabulate[widechars]>=0.9.0",
# httpx used within openllm.client
"httpx",
# for typing support
"typing_extensions",
]
description = 'OpenLLM: REST/gRPC API server for running any open Large-Language Model - StableLM, Llama, Alpaca, Dolly, Flan-T5, Custom'
4 changes: 2 additions & 2 deletions src/openllm/_configuration.py
@@ -1026,11 +1026,11 @@ def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]:
return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove}

@t.overload
def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny:
def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
...

@t.overload
def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny:
...

def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny:
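
The swap above aligns the overload that carries the default (return_as_dict: bool = False) with t.Literal[False], so a bare call type-checks as transformers.GenerationConfig while passing True yields a plain dict. A rough illustration, assuming openllm.AutoConfig.for_model as one way to obtain a concrete LLMConfig (that helper is not part of this diff):

import openllm

# Hypothetical: any concrete LLMConfig would do; "flan-t5" is one of the models
# named in the project description above.
config = openllm.AutoConfig.for_model("flan-t5")

generation_config = config.to_generation_config()      # transformers.GenerationConfig
generation_dict = config.to_generation_config(True)    # plain dict (DictStrAny)
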
4 changes: 2 additions & 2 deletions src/openllm/utils/dantic.py
@@ -280,7 +280,7 @@ def __init__(self, enum: Enum, case_sensitive: bool = False):
self.internal_type = enum
super().__init__([e.name for e in self.mapping], case_sensitive)

def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
if isinstance(value, self.internal_type):
return value
result = super().convert(value, param, ctx)
@@ -292,7 +292,7 @@ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Contex
class LiteralChoice(EnumChoice):
name = "literal"

def __init__(self, enum: t.Literal, case_sensitive: bool = False):
def __init__(self, enum: t.LiteralString, case_sensitive: bool = False):
# expect every literal value to belong to the same primitive type
values = list(enum.__args__)
item_type = type(values[0])
27 changes: 25 additions & 2 deletions src/openllm_client/runtimes/base.py
@@ -14,10 +14,13 @@

from __future__ import annotations

import asyncio
import typing as t
from abc import abstractmethod
from urllib.parse import urljoin

import bentoml
import httpx

import openllm

@@ -43,6 +43,14 @@ def metadata_v1(self) -> dict[str, t.Any]:
...


def in_async_context() -> bool:
try:
_ = asyncio.get_running_loop()
return True
except RuntimeError:
return False


class ClientMixin:
_api_version: str
_client_class: type[bentoml.client.Client]
@@ -57,12 +68,17 @@ def __init__(self, address: str, timeout: int = 30):
self._address = address
self._timeout = timeout
assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
self._metadata = self.call("metadata")

def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
cls._api_version = api_version

@property
def _metadata(self) -> dict[str, t.Any]:
if in_async_context():
return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
return self.call("metadata")

@property
@abstractmethod
def model_name(self) -> str:
@@ -140,7 +156,14 @@ def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **at
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs)
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
result = self.call("generate", inputs)
if in_async_context():
result = httpx.post(
urljoin(self._address, f"/{self._api_version}/generate"),
json=openllm.utils.bentoml_cattr.unstructure(inputs),
timeout=self.timeout,
).json()
else:
result = self.call("generate", inputs)
r = self.postprocess(result)

if return_raw_response:
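
As a rough sketch of what the new async-context branch in query does on the wire, assuming a server listening on http://localhost:3000 (the hand-written payload only approximates what bentoml_cattr.unstructure produces for a GenerationInput):

from urllib.parse import urljoin

import httpx

address = "http://localhost:3000"
api_version = "v1"

# Approximation of the unstructured GenerationInput payload; the real client
# builds this from openllm.GenerationInput via bentoml_cattr.unstructure.
payload = {"prompt": "What is the weather in SF?", "llm_config": {}}

result = httpx.post(
    urljoin(address, f"/{api_version}/generate"),
    json=payload,
    timeout=30,
).json()
print(result)
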
5 changes: 4 additions & 1 deletion src/openllm_client/runtimes/grpc.py
@@ -73,7 +73,10 @@ def configuration(self) -> dict[str, t.Any]:
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")

def postprocess(self, result: Response) -> openllm.GenerationOutput:
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
if isinstance(result, dict):
return openllm.GenerationOutput(**result)

from google.protobuf.json_format import MessageToDict

return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))