From 8e5c8843cdba3850b91d4fc4627cdd7ef89bfa75 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Tue, 19 Mar 2024 17:14:31 +0000 Subject: [PATCH 1/4] Add wait and fixes types --- src/cohere/client.py | 11 ++- src/cohere/utils.py | 157 +++++++++++++++++++++++++++++++++++++ tests/test_async_client.py | 81 +++++++++---------- tests/test_client.py | 19 ++--- 4 files changed, 210 insertions(+), 58 deletions(-) create mode 100644 src/cohere/utils.py diff --git a/src/cohere/client.py b/src/cohere/client.py index abdf34070..3ae1a03b9 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -4,6 +4,8 @@ from .base_client import BaseCohere, AsyncBaseCohere from .environment import ClientEnvironment +from .utils import wait, async_wait + # Use NoReturn as Never type for compatibility Never = typing.NoReturn @@ -25,6 +27,7 @@ def throw_if_stream_is_true(*args, **kwargs) -> None: "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)" ) + def moved_function(fn_name: str, new_fn_name: str) -> typing.Any: """ This method is moved. Please update usage. @@ -56,7 +59,7 @@ def fn(*args, **kwargs): class Client(BaseCohere): def __init__( self, - api_key: typing.Union[str, typing.Callable[[], str]], + api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None, *, base_url: typing.Optional[str] = None, environment: ClientEnvironment = ClientEnvironment.PRODUCTION, @@ -76,6 +79,8 @@ def __init__( validate_args(self, "chat", throw_if_stream_is_true) + wait = wait + """ The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues. @@ -125,7 +130,7 @@ def __init__( class AsyncClient(AsyncBaseCohere): def __init__( self, - api_key: typing.Union[str, typing.Callable[[], str]], + api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None, *, base_url: typing.Optional[str] = None, environment: ClientEnvironment = ClientEnvironment.PRODUCTION, @@ -145,6 +150,8 @@ def __init__( validate_args(self, "chat", throw_if_stream_is_true) + wait = async_wait + """ The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues. diff --git a/src/cohere/utils.py b/src/cohere/utils.py new file mode 100644 index 000000000..8b551b544 --- /dev/null +++ b/src/cohere/utils.py @@ -0,0 +1,157 @@ +import asyncio +import asyncio +import time +import typing +from typing import Awaitable, Optional + +from .types import EmbedJob, CreateEmbedJobResponse +from .datasets import DatasetsCreateResponse, DatasetsGetResponse + + +def is_dataset_create_response(awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> bool: + return isinstance(awaitable, "DatasetsCreateResponse") + + +def get_terminal_states(): + return get_success_states() | get_failed_states() + + +def get_success_states(): + return {"complete", "validated"} + + +def get_failed_states(): + return {"unknown", "failed", "skipped", "cancelled", "failed"} + + +def get_id( + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> str: + return getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None) or getattr( + getattr(awaitable, "dataset", None), "id", None) + + +def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]) -> str: + return getattr(awaitable, "status", None) or getattr(getattr(awaitable, "dataset", None), "validation_status", None) + + +def get_job(cohere: typing.Any, + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \ +typing.Union[ + EmbedJob, DatasetsGetResponse]: + if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse": + return cohere.embed_jobs.get(id=get_id(awaitable)) + elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse": + return cohere.datasets.get(id=get_id(awaitable)) + else: + raise ValueError(f"Unexpected awaitable type {awaitable}") + + +async def async_get_job(cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> \ + typing.Union[ + EmbedJob, DatasetsGetResponse]: + if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse": + return await cohere.embed_jobs.get(id=get_id(awaitable)) + elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse": + return await cohere.datasets.get(id=get_id(awaitable)) + else: + raise ValueError(f"Unexpected awaitable type {awaitable}") + + +def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]: + if isinstance(job, EmbedJob): + return f"Embed job {job.job_id} failed with status {job.status}" + elif isinstance(job, DatasetsGetResponse): + return f"Dataset creation {job.dataset.validation_status} failed with status {job.dataset.validation_status}" + return None + + +@typing.overload +def wait( + cohere: typing.Any, + awaitable: CreateEmbedJobResponse, + timeout: Optional[float] = None, + interval: float = 10, +) -> EmbedJob: + ... + + +@typing.overload +def wait( + cohere: typing.Any, + awaitable: DatasetsCreateResponse, + timeout: Optional[float] = None, + interval: float = 10, +) -> DatasetsGetResponse: + ... + + +def wait( + cohere: typing.Any, + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse], + timeout: Optional[float] = None, + interval: float = 2, +) -> EmbedJob | DatasetsGetResponse: + start_time = time.time() + terminal_states = get_terminal_states() + failed_states = get_failed_states() + + job = get_job(cohere, awaitable) + while get_validation_status(job) not in terminal_states: + if timeout is not None and time.time() - start_time > timeout: + raise TimeoutError(f"wait timed out after {timeout} seconds") + + time.sleep(interval) + print("...") + + job = get_job(cohere, awaitable) + + if get_validation_status(job) in failed_states: + raise Exception(get_failure_reason(job)) + + return job + + +@typing.overload +def async_wait( + cohere: typing.Any, + awaitable: CreateEmbedJobResponse, + timeout: Optional[float] = None, + interval: float = 10, +) -> Awaitable[EmbedJob]: + ... + + +@typing.overload +def async_wait( + cohere: typing.Any, + awaitable: DatasetsCreateResponse, + timeout: Optional[float] = None, + interval: float = 10, +) -> Awaitable[DatasetsGetResponse]: + ... + + +async def async_wait( + cohere: typing.Any, + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse], + timeout: Optional[float] = None, + interval: float = 10, +) -> EmbedJob | DatasetsGetResponse: + start_time = time.time() + terminal_states = get_terminal_states() + failed_states = get_failed_states() + + job = await async_get_job(cohere, awaitable) + while get_validation_status(job) not in terminal_states: + if timeout is not None and time.time() - start_time > timeout: + raise TimeoutError(f"wait timed out after {timeout} seconds") + + await asyncio.sleep(interval) + print("...") + + job = await async_get_job(cohere, awaitable) + + if get_validation_status(job) in failed_states: + raise Exception(get_failure_reason(job)) + + return job diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 412b2a24d..07dee5089 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -1,21 +1,22 @@ import os import unittest -from time import sleep import cohere from cohere import ChatMessage, ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \ ToolParameterDefinitionsValue, ChatRequestToolResultsItem -co = cohere.AsyncClient(os.environ['COHERE_API_KEY'], timeout=10000) - package_dir = os.path.dirname(os.path.abspath(__file__)) embed_job = os.path.join(package_dir, 'embed_job.jsonl') -class TestClient(unittest.TestCase): +class TestClient(unittest.IsolatedAsyncioTestCase): + co: cohere.AsyncClient + + def setUp(self) -> None: + self.co = cohere.AsyncClient(os.environ['COHERE_API_KEY'], timeout=10000) async def test_chat(self) -> None: - chat = await co.chat( + chat = await self.co.chat( chat_history=[ ChatMessage(role="USER", message="Who discovered gravity?"), @@ -29,7 +30,7 @@ async def test_chat(self) -> None: print(chat) async def test_chat_stream(self) -> None: - stream = co.chat_stream( + stream = self.co.chat_stream( chat_history=[ ChatMessage(role="USER", message="Who discovered gravity?"), @@ -46,27 +47,27 @@ async def test_chat_stream(self) -> None: async def test_stream_equals_true(self) -> None: with self.assertRaises(ValueError): - await co.chat( + await self.co.chat( stream=True, # type: ignore message="What year was he born?", ) async def test_deprecated_fn(self) -> None: with self.assertRaises(ValueError): - await co.check_api_key("dummy", dummy="dummy") # type: ignore + await self.co.check_api_key("dummy", dummy="dummy") # type: ignore async def test_moved_fn(self) -> None: with self.assertRaises(ValueError): - await co.list_connectors("dummy", dummy="dummy") # type: ignore + await self.co.list_connectors("dummy", dummy="dummy") # type: ignore async def test_generate(self) -> None: - response = await co.generate( + response = await self.co.generate( prompt='Please explain to me how LLMs work', ) print(response) async def test_embed(self) -> None: - response = await co.embed( + response = await self.co.embed( texts=['hello', 'goodbye'], model='embed-english-v3.0', input_type="classification" @@ -74,21 +75,18 @@ async def test_embed(self) -> None: print(response) async def test_embed_job_crud(self) -> None: - dataset = await co.datasets.create( + dataset = await self.co.datasets.create( name="test", type="embed-input", data=open(embed_job, 'rb'), ) - while True: - ds = await co.datasets.get(dataset.id or "") - sleep(2) - print(ds, flush=True) - if ds.dataset.validation_status != "processing": - break + result = await self.co.wait(dataset) + + self.assertEqual(result.dataset.validation_status, "validated") # start an embed job - job = await co.embed_jobs.create( + job = await self.co.embed_jobs.create( dataset_id=dataset.id or "", input_type="search_document", model='embed-english-v3.0') @@ -96,20 +94,17 @@ async def test_embed_job_crud(self) -> None: print(job) # list embed jobs - my_embed_jobs = await co.embed_jobs.list() + my_embed_jobs = await self.co.embed_jobs.list() print(my_embed_jobs) - while True: - em = await co.embed_jobs.get(job.job_id) - sleep(2) - print(em, flush=True) - if em.status != "processing": - break + result = await self.co.wait(job) + + self.assertEqual(result.status, "complete") - await co.embed_jobs.cancel(job.job_id) + await self.co.embed_jobs.cancel(job.job_id) - await co.datasets.delete(dataset.id or "") + await self.co.datasets.delete(dataset.id or "") async def test_rerank(self) -> None: docs = [ @@ -118,7 +113,7 @@ async def test_rerank(self) -> None: 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.', 'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.'] - response = await co.rerank( + response = await self.co.rerank( model='rerank-english-v2.0', query='What is the capital of the United States?', documents=docs, @@ -148,14 +143,14 @@ async def test_classify(self) -> None: "Confirm your email address", "hey i need u to send some $", ] - response = await co.classify( + response = await self.co.classify( inputs=inputs, examples=examples, ) print(response) async def test_datasets_crud(self) -> None: - my_dataset = await co.datasets.create( + my_dataset = await self.co.datasets.create( name="test", type="embed-input", data=open(embed_job, 'rb'), @@ -163,15 +158,15 @@ async def test_datasets_crud(self) -> None: print(my_dataset) - my_datasets = await co.datasets.list() + my_datasets = await self.co.datasets.list() print(my_datasets) - dataset = await co.datasets.get(my_dataset.id or "") + dataset = await self.co.datasets.get(my_dataset.id or "") print(dataset) - await co.datasets.delete(my_dataset.id or "") + await self.co.datasets.delete(my_dataset.id or "") async def test_summarize(self) -> None: text = ( @@ -197,28 +192,28 @@ async def test_summarize(self) -> None: "lactose intolerant, allergic to dairy protein or vegan." ) - response = await co.summarize( + response = await self.co.summarize( text=text, ) print(response) async def test_tokenize(self) -> None: - response = await co.tokenize( + response = await self.co.tokenize( text='tokenize me! :D', model='command' ) print(response) async def test_detokenize(self) -> None: - response = await co.detokenize( + response = await self.co.detokenize( tokens=[10104, 12221, 1315, 34, 1420, 69], model="command" ) print(response) async def test_connectors_crud(self) -> None: - created_connector = await co.connectors.create( + created_connector = await self.co.connectors.create( name="Example connector", url="https://dummy-connector-o5btz7ucgq-uc.a.run.app/search", service_auth=CreateConnectorServiceAuth( @@ -228,16 +223,16 @@ async def test_connectors_crud(self) -> None: ) print(created_connector) - connector = await co.connectors.get(created_connector.connector.id) + connector = await self.co.connectors.get(created_connector.connector.id) print(connector) - updated_connector = await co.connectors.update( + updated_connector = await self.co.connectors.update( id=connector.connector.id, name="new name") print(updated_connector) - await co.connectors.delete(created_connector.connector.id) + await self.co.connectors.delete(created_connector.connector.id) async def test_tool_use(self) -> None: tools = [ @@ -253,7 +248,7 @@ async def test_tool_use(self) -> None: ) ] - tool_parameters_response = await co.chat( + tool_parameters_response = await self.co.chat( message="How good were the sales on September 29?", tools=tools, model="command-nightly", @@ -291,7 +286,7 @@ async def test_tool_use(self) -> None: outputs=outputs )) - cited_response = await co.chat( + cited_response = await self.co.chat( message="How good were the sales on September 29?", tools=tools, tool_results=tool_results, diff --git a/tests/test_client.py b/tests/test_client.py index 9b2d75ee7..545b77ec8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,5 @@ import os import unittest -from time import sleep import cohere from cohere import ChatMessage, ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \ @@ -81,12 +80,9 @@ def test_embed_job_crud(self) -> None: data=open(embed_job, 'rb'), ) - while True: - ds = co.datasets.get(dataset.id or "") - sleep(2) - print(ds, flush=True) - if ds.dataset.validation_status != "processing": - break + result = co.wait(dataset) + + self.assertEqual(result.dataset.validation_status, "validated") # start an embed job job = co.embed_jobs.create( @@ -101,12 +97,9 @@ def test_embed_job_crud(self) -> None: print(my_embed_jobs) - while True: - em = co.embed_jobs.get(job.job_id) - sleep(2) - print(em, flush=True) - if em.status != "processing": - break + result = co.wait(job) + + self.assertEqual(result.status, "complete") co.embed_jobs.cancel(job.job_id) From 2c671f95163946ea19dbc9b57fda894e6cee41e7 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Tue, 19 Mar 2024 17:19:26 +0000 Subject: [PATCH 2/4] Fix types --- src/cohere/utils.py | 17 +++++++---------- tests/test_async_client.py | 4 ++-- tests/test_client.py | 4 ++-- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/cohere/utils.py b/src/cohere/utils.py index 8b551b544..35e6f422b 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -4,14 +4,11 @@ import typing from typing import Awaitable, Optional +from . import EmbedJob, DatasetsGetResponse from .types import EmbedJob, CreateEmbedJobResponse from .datasets import DatasetsCreateResponse, DatasetsGetResponse -def is_dataset_create_response(awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> bool: - return isinstance(awaitable, "DatasetsCreateResponse") - - def get_terminal_states(): return get_success_states() | get_failed_states() @@ -25,12 +22,12 @@ def get_failed_states(): def get_id( - awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> str: + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]): return getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None) or getattr( getattr(awaitable, "dataset", None), "id", None) -def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]) -> str: +def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]): return getattr(awaitable, "status", None) or getattr(getattr(awaitable, "dataset", None), "validation_status", None) @@ -112,22 +109,22 @@ def wait( @typing.overload -def async_wait( +async def async_wait( cohere: typing.Any, awaitable: CreateEmbedJobResponse, timeout: Optional[float] = None, interval: float = 10, -) -> Awaitable[EmbedJob]: +) -> EmbedJob: ... @typing.overload -def async_wait( +async def async_wait( cohere: typing.Any, awaitable: DatasetsCreateResponse, timeout: Optional[float] = None, interval: float = 10, -) -> Awaitable[DatasetsGetResponse]: +) -> DatasetsGetResponse: ... diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 07dee5089..042172337 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -98,9 +98,9 @@ async def test_embed_job_crud(self) -> None: print(my_embed_jobs) - result = await self.co.wait(job) + emb_result = await self.co.wait(job) - self.assertEqual(result.status, "complete") + self.assertEqual(emb_result.status, "complete") await self.co.embed_jobs.cancel(job.job_id) diff --git a/tests/test_client.py b/tests/test_client.py index 545b77ec8..56a15a8eb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -97,9 +97,9 @@ def test_embed_job_crud(self) -> None: print(my_embed_jobs) - result = co.wait(job) + emb_result = co.wait(job) - self.assertEqual(result.status, "complete") + self.assertEqual(emb_result.status, "complete") co.embed_jobs.cancel(job.job_id) From 9416e734f90f2f28d13c4e7b7e74565210895446 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Tue, 19 Mar 2024 17:21:05 +0000 Subject: [PATCH 3/4] Fix imports --- src/cohere/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cohere/utils.py b/src/cohere/utils.py index 35e6f422b..1f63bfa4a 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -1,10 +1,8 @@ import asyncio -import asyncio import time import typing -from typing import Awaitable, Optional +from typing import Optional -from . import EmbedJob, DatasetsGetResponse from .types import EmbedJob, CreateEmbedJobResponse from .datasets import DatasetsCreateResponse, DatasetsGetResponse From cf1f466872425485c1e5c33fffd17a7f2f345123 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Tue, 19 Mar 2024 17:22:24 +0000 Subject: [PATCH 4/4] Fix type --- src/cohere/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cohere/utils.py b/src/cohere/utils.py index 1f63bfa4a..249713f25 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -85,7 +85,7 @@ def wait( awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse], timeout: Optional[float] = None, interval: float = 2, -) -> EmbedJob | DatasetsGetResponse: +) -> typing.Union[EmbedJob, DatasetsGetResponse]: start_time = time.time() terminal_states = get_terminal_states() failed_states = get_failed_states() @@ -131,7 +131,7 @@ async def async_wait( awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse], timeout: Optional[float] = None, interval: float = 10, -) -> EmbedJob | DatasetsGetResponse: +) -> typing.Union[EmbedJob, DatasetsGetResponse]: start_time = time.time() terminal_states = get_terminal_states() failed_states = get_failed_states()