From 6a0bb46421088716b86284dc8f896a0893bf32fc Mon Sep 17 00:00:00 2001
From: Lev Vereshchagin
Date: Fri, 22 Nov 2024 16:30:16 +0300
Subject: [PATCH 1/4] Allow extra params in request body

---
 any_llm_client/clients/mock.py      | 10 +++++++++-
 any_llm_client/clients/openai.py    |  9 +++++++--
 any_llm_client/clients/yandexgpt.py | 24 ++++++++++++++++++------
 any_llm_client/core.py              |  4 ++--
 4 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/any_llm_client/clients/mock.py b/any_llm_client/clients/mock.py
index 70a1e52..7e07302 100644
--- a/any_llm_client/clients/mock.py
+++ b/any_llm_client/clients/mock.py
@@ -19,7 +19,13 @@ class MockLLMConfig(LLMConfig):
 class MockLLMClient(LLMClient):
     config: MockLLMConfig

-    async def request_llm_message(self, messages: str | list[Message], temperature: float = 0.2) -> str:  # noqa: ARG002
+    async def request_llm_message(
+        self,
+        messages: str | list[Message],  # noqa: ARG002
+        *,
+        temperature: float = 0.2,  # noqa: ARG002
+        extra: dict[str, typing.Any] | None = None,  # noqa: ARG002
+    ) -> str:
         return self.config.response_message

     async def _iter_config_stream_messages(self) -> typing.AsyncIterable[str]:
@@ -30,7 +36,9 @@ async def _iter_config_stream_messages(self) -> typing.AsyncIterable[str]:
     async def stream_llm_partial_messages(
         self,
         messages: str | list[Message],  # noqa: ARG002
+        *,
         temperature: float = 0.2,  # noqa: ARG002
+        extra: dict[str, typing.Any] | None = None,  # noqa: ARG002
     ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
         yield self._iter_config_stream_messages()

diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py
index 332b46c..65bb5d2 100644
--- a/any_llm_client/clients/openai.py
+++ b/any_llm_client/clients/openai.py
@@ -45,6 +45,7 @@ class ChatCompletionsMessage(pydantic.BaseModel):


 class ChatCompletionsRequest(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra="allow")
     stream: bool
     model: str
     messages: list[ChatCompletionsMessage]
@@ -140,12 +141,15 @@ def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletio
             else list(initial_messages)
         )

-    async def request_llm_message(self, messages: str | list[Message], temperature: float = 0.2) -> str:
+    async def request_llm_message(
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
+    ) -> str:
         payload: typing.Final = ChatCompletionsRequest(
             stream=False,
             model=self.config.model_name,
             messages=self._prepare_messages(messages),
             temperature=temperature,
+            **extra or {},
         ).model_dump(mode="json")
         try:
             response: typing.Final = await make_http_request(
@@ -173,13 +177,14 @@ async def _iter_partial_responses(self, response: httpx.Response) -> typing.Asyn

     @contextlib.asynccontextmanager
     async def stream_llm_partial_messages(
-        self, messages: str | list[Message], temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
     ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
         payload: typing.Final = ChatCompletionsRequest(
             stream=True,
             model=self.config.model_name,
             messages=self._prepare_messages(messages),
             temperature=temperature,
+            **extra or {},
         ).model_dump(mode="json")
         try:
             async with make_streaming_http_request(
diff --git a/any_llm_client/clients/yandexgpt.py b/any_llm_client/clients/yandexgpt.py
index b684800..10c8818 100644
--- a/any_llm_client/clients/yandexgpt.py
+++ b/any_llm_client/clients/yandexgpt.py
@@ -43,7 +43,7 @@ class YandexGPTCompletionOptions(pydantic.BaseModel):


 class YandexGPTRequest(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(protected_namespaces=())
+    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="allow")
     model_uri: str = pydantic.Field(alias="modelUri")
     completion_options: YandexGPTCompletionOptions = pydantic.Field(alias="completionOptions")
     messages: list[Message]
@@ -96,7 +96,12 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         )

     def _prepare_payload(
-        self, *, messages: str | list[Message], temperature: float = 0.2, stream: bool
+        self,
+        *,
+        messages: str | list[Message],
+        temperature: float = 0.2,
+        stream: bool,
+        extra: dict[str, typing.Any] | None,
     ) -> dict[str, typing.Any]:
         messages = [UserMessage(messages)] if isinstance(messages, str) else messages
         return YandexGPTRequest(
@@ -105,10 +110,15 @@ def _prepare_payload(
                 stream=stream, temperature=temperature, maxTokens=self.config.max_tokens
             ),
             messages=messages,
+            **extra or {},
         ).model_dump(mode="json", by_alias=True)

-    async def request_llm_message(self, messages: str | list[Message], temperature: float = 0.2) -> str:
-        payload: typing.Final = self._prepare_payload(messages=messages, temperature=temperature, stream=False)
+    async def request_llm_message(
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
+    ) -> str:
+        payload: typing.Final = self._prepare_payload(
+            messages=messages, temperature=temperature, stream=False, extra=extra
+        )

         try:
             response: typing.Final = await make_http_request(
@@ -128,9 +138,11 @@ async def _iter_completion_messages(self, response: httpx.Response) -> typing.As

     @contextlib.asynccontextmanager
     async def stream_llm_partial_messages(
-        self, messages: str | list[Message], temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
     ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
-        payload: typing.Final = self._prepare_payload(messages=messages, temperature=temperature, stream=True)
+        payload: typing.Final = self._prepare_payload(
+            messages=messages, temperature=temperature, stream=True, extra=extra
+        )

         try:
             async with make_streaming_http_request(
diff --git a/any_llm_client/core.py b/any_llm_client/core.py
index b071ec9..e0ced36 100644
--- a/any_llm_client/core.py
+++ b/any_llm_client/core.py
@@ -68,12 +68,12 @@ class LLMConfig(pydantic.BaseModel):
 @dataclasses.dataclass(slots=True, init=False)
 class LLMClient(typing.Protocol):
     async def request_llm_message(
-        self, messages: str | list[Message], *, temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
     ) -> str: ...  # raises LLMError

     @contextlib.asynccontextmanager
     def stream_llm_partial_messages(
-        self, messages: str | list[Message], temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
     ) -> typing.AsyncIterator[typing.AsyncIterable[str]]: ...  # raises LLMError

     async def __aenter__(self) -> typing_extensions.Self: ...
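The patch above leans on pydantic's `extra="allow"` model config: keyword arguments that are not declared fields are kept on the model instance and survive `model_dump()`, so whatever the caller passes in `extra` is merged into the request body sent to the LLM API. Below is a minimal standalone sketch of that mechanism — not part of the patch; `ExampleRequest` and `build_payload` are made-up stand-ins for the request models and client code above.

```python
# Standalone sketch: `extra="allow"` plus `**extra or {}` lets undeclared keys
# flow into the serialized request body. ExampleRequest / build_payload are
# hypothetical stand-ins for ChatCompletionsRequest / YandexGPTRequest.
import typing

import pydantic


class ExampleRequest(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="allow")
    stream: bool
    temperature: float = 0.2


def build_payload(*, extra: dict[str, typing.Any] | None = None) -> dict[str, typing.Any]:
    # Unknown keys are stored on the model and included by model_dump().
    return ExampleRequest(stream=False, temperature=0.7, **extra or {}).model_dump(mode="json")


assert build_payload(extra={"best_of": 3}) == {"stream": False, "temperature": 0.7, "best_of": 3}
```

The test added in the next patch exercises this behaviour against the real `YandexGPTRequest` and `ChatCompletionsRequest` models.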
From e796ad7e4dd5e71995f64f30574a1eeb690d409b Mon Sep 17 00:00:00 2001
From: Lev Vereshchagin
Date: Fri, 22 Nov 2024 16:40:27 +0300
Subject: [PATCH 2/4] Add tests

---
 tests/conftest.py    |  1 +
 tests/test_static.py | 15 +++++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/tests/conftest.py b/tests/conftest.py
index f2bfb6d..c003365 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,6 +21,7 @@ def _deactivate_retries() -> None:
 class LLMFuncRequest(typing.TypedDict):
     messages: str | list[any_llm_client.Message]
     temperature: float
+    extra: dict[str, typing.Any] | None


 class LLMFuncRequestFactory(TypedDictFactory[LLMFuncRequest]): ...
diff --git a/tests/test_static.py b/tests/test_static.py
index d5f2fed..4d73521 100644
--- a/tests/test_static.py
+++ b/tests/test_static.py
@@ -2,9 +2,14 @@
 import typing

 import faker
+import pydantic
+import pytest
 import stamina
+from polyfactory.factories.pydantic_factory import ModelFactory

 import any_llm_client
+from any_llm_client.clients.openai import ChatCompletionsRequest
+from any_llm_client.clients.yandexgpt import YandexGPTRequest
 from tests.conftest import LLMFuncRequest


@@ -40,3 +45,13 @@ def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:
         annotations.pop(one_ignored_prop)

     assert all(annotations == all_annotations[0] for annotations in all_annotations)
+
+
+@pytest.mark.parametrize("model_type", [YandexGPTRequest, ChatCompletionsRequest])
+def test_payload_adds_extra_keys(model_type: type[pydantic.BaseModel]) -> None:
+    extra: typing.Final = {"hi": "there", "hi-hi": "there-there"}
+    generated_data: typing.Final = ModelFactory.create_factory(model_type).build(**extra).model_dump(by_alias=True)  # type: ignore[arg-type]
+    dumped_model: typing.Final = model_type(**{**generated_data, **extra}).model_dump(mode="json", by_alias=True)
+
+    assert dumped_model["hi"] == "there"
+    assert dumped_model["hi-hi"] == "there-there"

From 3ce9dd70f25321ee638258e0f2fc2c74f0ca4911 Mon Sep 17 00:00:00 2001
From: Lev Vereshchagin
Date: Fri, 22 Nov 2024 16:40:58 +0300
Subject: [PATCH 3/4] Update

---
 tests/test_static.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_static.py b/tests/test_static.py
index 4d73521..5ec106f 100644
--- a/tests/test_static.py
+++ b/tests/test_static.py
@@ -48,7 +48,7 @@ def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:


 @pytest.mark.parametrize("model_type", [YandexGPTRequest, ChatCompletionsRequest])
-def test_payload_adds_extra_keys(model_type: type[pydantic.BaseModel]) -> None:
+def test_dumped_llm_request_payload_dump_has_extra_data(model_type: type[pydantic.BaseModel]) -> None:
     extra: typing.Final = {"hi": "there", "hi-hi": "there-there"}
     generated_data: typing.Final = ModelFactory.create_factory(model_type).build(**extra).model_dump(by_alias=True)  # type: ignore[arg-type]
     dumped_model: typing.Final = model_type(**{**generated_data, **extra}).model_dump(mode="json", by_alias=True)

From 7ac52bbbae1d4439e21486bfd81a05b9a2f2fc43 Mon Sep 17 00:00:00 2001
From: Lev Vereshchagin
Date: Fri, 22 Nov 2024 16:46:22 +0300
Subject: [PATCH 4/4] Update readme

---
 README.md | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index fdd2a94..395ac8e 100644
--- a/README.md
+++ b/README.md
@@ -99,7 +99,9 @@ config = any_llm_client.MockLLMConfig(
     response_message=...,
     stream_messages=["Hi!"],
 )
-client = any_llm_client.get_client(config, ...)
+
+async with any_llm_client.get_client(config, ...) as client:
+    ...
 ```

 #### Configuration with environment variables
@@ -131,7 +133,9 @@ os.environ["LLM_MODEL"] = """{
     "model_name": "qwen2.5-coder:1.5b"
 }"""
 settings = Settings()
-client = any_llm_client.get_client(settings.llm_model, ...)
+
+async with any_llm_client.get_client(settings.llm_model, ...) as client:
+    ...
 ```

 Combining with environment variables from previous section, you can keep LLM model configuration and secrets separate.
@@ -146,7 +150,9 @@ config = any_llm_client.OpenAIConfig(
     auth_token=os.environ["OPENAI_API_KEY"],
     model_name="gpt-4o-mini",
 )
-client = any_llm_client.OpenAIClient(config, ...)
+
+async with any_llm_client.OpenAIClient(config, ...) as client:
+    ...
 ```

 #### Errors
@@ -179,5 +185,12 @@ Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unl
 By default, requests are retried 3 times on HTTP status errors. You can change the retry behaviour by supplying `request_retry` parameter:

 ```python
-client = any_llm_client.get_client(..., request_retry=any_llm_client.RequestRetryConfig(attempts=5, ...))
+async with any_llm_client.get_client(..., request_retry=any_llm_client.RequestRetryConfig(attempts=5, ...)) as client:
+    ...
 ```
+
+#### Passing extra data to LLM
+
+```python
+await client.request_llm_message("Кек, чо как вообще на нарах?", extra={"best_of": 3})
+```
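The streaming API accepts the same `extra` parameter. A usage sketch, assuming an already-configured `client` like in the README examples above; `best_of` is only an example key — the dict is merged into the request body as-is:

```python
# Sketch assuming `client` was created as in the examples above; `best_of` is an
# arbitrary example key that is passed through to the request body untouched.
async with client.stream_llm_partial_messages(
    "Tell me about extra request parameters", extra={"best_of": 3}
) as partial_messages:
    async for partial_message in partial_messages:
        print(partial_message)
```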