From b61ec3d8e74ba8d5ad0a02741548ff6be52abc43 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 18 Mar 2024 12:14:22 +0000 Subject: [PATCH 1/2] Throw if stream=True --- src/cohere/client.py | 21 +++++++++++++++++++++ tests/test_client.py | 7 +++++++ 2 files changed, 28 insertions(+) diff --git a/src/cohere/client.py b/src/cohere/client.py index d503c5610..0db105e3c 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -6,6 +6,23 @@ from .environment import CohereEnvironment +def validate_args(obj: typing.Any, method_name: str, check_fn) -> typing.Any: + method = getattr(obj, method_name) + + def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + check_fn(*args, **kwargs) + return method(*args, **kwargs) + + setattr(obj, method_name, wrapped) + + +def throw_if_stream_is_true(*args, **kwargs) -> None: + if kwargs.get("stream") is True: + raise ValueError( + "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)" + ) + + class Client(BaseCohere): def __init__( self, @@ -27,6 +44,8 @@ def __init__( httpx_client=httpx_client, ) + validate_args(self, "chat", throw_if_stream_is_true) + class AsyncClient(AsyncBaseCohere): def __init__( @@ -48,3 +67,5 @@ def __init__( timeout=timeout, httpx_client=httpx_client, ) + + validate_args(self, "chat", throw_if_stream_is_true) \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py index 58b3d5b66..8baff6af2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -28,6 +28,13 @@ def test_chat(self) -> None: print(chat) + def test_stream_equals_true(self) -> None: + with self.assertRaises(ValueError): + co.chat( + stream=True, # type: ignore + message="What year was he born?", + ) + def test_generate(self) -> None: response = co.generate( prompt='Please explain to me how LLMs work', From c6589e9698019966531d5e7a9451ac8fb648a008 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 18 Mar 2024 12:40:32 +0000 Subject: [PATCH 2/2] Fix types --- src/cohere/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cohere/client.py b/src/cohere/client.py index 0db105e3c..e7603a293 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -6,7 +6,7 @@ from .environment import CohereEnvironment -def validate_args(obj: typing.Any, method_name: str, check_fn) -> typing.Any: +def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None: method = getattr(obj, method_name) def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: