From 664db5ab03807a5629a04d9997eb8b2ceffab488 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Wed, 2 Aug 2023 19:20:53 -0700 Subject: [PATCH 1/2] Fix Async Retry Event Handling --- libs/langchain/langchain/llms/base.py | 6 +- .../integration_tests/llms/test_openai.py | 60 ------------------- .../callbacks/fake_callback_handler.py | 7 +++ .../tests/unit_tests/llms/test_openai.py | 37 +++++++++--- 4 files changed, 41 insertions(+), 69 deletions(-) diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 8ccaf0deb645c3..7da494de78b64a 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -91,7 +91,11 @@ def _before_sleep(retry_state: RetryCallState) -> None: if isinstance(run_manager, AsyncCallbackManagerForLLMRun): coro = run_manager.on_retry(retry_state) try: - asyncio.run(coro) + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(coro) + else: + asyncio.run(coro) except Exception as e: _log_error_once(f"Error in on_retry: {e}") else: diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py index 6b584ae154b828..1e48e4ce690017 100644 --- a/libs/langchain/tests/integration_tests/llms/test_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_openai.py @@ -351,63 +351,3 @@ def mock_completion() -> dict: ], "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, } - - -@pytest.mark.requires("openai") -def test_openai_retries(mock_completion: dict) -> None: - llm = OpenAI() - mock_client = MagicMock() - completed = False - raised = False - import openai - - def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed, raised - if not raised: - raised = True - raise openai.error.APIError - completed = True - return mock_completion - - mock_client.create = raise_once - callback_handler = FakeCallbackHandler() - with patch.object( - llm, - "client", - mock_client, - ): - res = llm.predict("bar", callbacks=[callback_handler]) - assert res == "Bar Baz" - assert completed - assert raised - assert callback_handler.retries == 1 - - -@pytest.mark.requires("openai") -async def test_openai_async_retries(mock_completion: dict) -> None: - llm = OpenAI() - mock_client = MagicMock() - completed = False - raised = False - import openai - - def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed, raised - if not raised: - raised = True - raise openai.error.APIError - completed = True - return mock_completion - - mock_client.create = raise_once - callback_handler = FakeAsyncCallbackHandler() - with patch.object( - llm, - "client", - mock_client, - ): - res = llm.apredict("bar", callbacks=[callback_handler]) - assert res == "Bar Baz" - assert completed - assert raised - assert callback_handler.retries == 1 diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index 87b56a9bff200f..f4819c6930efbe 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -290,6 +290,13 @@ def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" return self.ignore_agent_ + async def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + async def on_llm_start( self, *args: Any, diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py index cc0fc74c1f7fb2..6e765b5106cab3 100644 --- a/libs/langchain/tests/unit_tests/llms/test_openai.py +++ b/libs/langchain/tests/unit_tests/llms/test_openai.py @@ -1,3 +1,4 @@ +import asyncio import os from typing import Any from unittest.mock import MagicMock, patch @@ -5,6 +6,7 @@ import pytest from langchain.llms.openai import OpenAI +from tests.unit_tests.callbacks.fake_callback_handler import FakeAsyncCallbackHandler, FakeCallbackHandler os.environ["OPENAI_API_KEY"] = "foo" @@ -45,44 +47,63 @@ def mock_completion() -> dict: @pytest.mark.requires("openai") -def test_openai_calls(mock_completion: dict) -> None: +def test_openai_retries(mock_completion: dict) -> None: llm = OpenAI() mock_client = MagicMock() completed = False + raised = False + import openai def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed + nonlocal completed, raised + if not raised: + raised = True + raise openai.error.APIError completed = True return mock_completion mock_client.create = raise_once + callback_handler = FakeCallbackHandler() with patch.object( llm, "client", mock_client, ): - res = llm.predict("bar") + res = llm.predict("bar", callbacks=[callback_handler]) assert res == "Bar Baz" assert completed + assert raised + assert callback_handler.retries == 1 @pytest.mark.requires("openai") +@pytest.mark.asyncio async def test_openai_async_retries(mock_completion: dict) -> None: llm = OpenAI() mock_client = MagicMock() completed = False - - def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed + raised = False + import openai + + async def araise_once(*args: Any, **kwargs: Any) -> Any: + nonlocal completed, raised + if not raised: + raised = True + raise openai.error.APIError + asyncio.sleep(0.001) completed = True return mock_completion - mock_client.create = raise_once + mock_client.acreate = araise_once + callback_handler = FakeAsyncCallbackHandler() with patch.object( llm, "client", mock_client, ): - res = llm.apredict("bar") + res = await llm.apredict("bar", callbacks=[callback_handler]) assert res == "Bar Baz" assert completed + assert raised + assert callback_handler.retries == 1 + From 844ef9a62f2640e4f95f69747fb02795e318994c Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 3 Aug 2023 07:51:54 -0700 Subject: [PATCH 2/2] fix lint --- .../langchain/tests/integration_tests/llms/test_openai.py | 4 +--- libs/langchain/tests/unit_tests/llms/test_openai.py | 8 +++++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py index 1e48e4ce690017..ca8911078a48f9 100644 --- a/libs/langchain/tests/integration_tests/llms/test_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_openai.py @@ -1,7 +1,6 @@ """Test OpenAI API wrapper.""" from pathlib import Path -from typing import Any, Generator -from unittest.mock import MagicMock, patch +from typing import Generator import pytest @@ -11,7 +10,6 @@ from langchain.llms.openai import OpenAI, OpenAIChat from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( - FakeAsyncCallbackHandler, FakeCallbackHandler, ) diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py index 6e765b5106cab3..54750a95921cc5 100644 --- a/libs/langchain/tests/unit_tests/llms/test_openai.py +++ b/libs/langchain/tests/unit_tests/llms/test_openai.py @@ -6,7 +6,10 @@ import pytest from langchain.llms.openai import OpenAI -from tests.unit_tests.callbacks.fake_callback_handler import FakeAsyncCallbackHandler, FakeCallbackHandler +from tests.unit_tests.callbacks.fake_callback_handler import ( + FakeAsyncCallbackHandler, + FakeCallbackHandler, +) os.environ["OPENAI_API_KEY"] = "foo" @@ -90,7 +93,7 @@ async def araise_once(*args: Any, **kwargs: Any) -> Any: if not raised: raised = True raise openai.error.APIError - asyncio.sleep(0.001) + await asyncio.sleep(0) completed = True return mock_completion @@ -106,4 +109,3 @@ async def araise_once(*args: Any, **kwargs: Any) -> Any: assert completed assert raised assert callback_handler.retries == 1 -