From aa8ae31e5b6fe45cfc970740957f228bf44cdced Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?=
Date: Tue, 5 Dec 2023 00:04:48 +0100
Subject: [PATCH] core[patch]: add response kwarg to on_llm_error

# Dependencies
None

# Twitter handle
@HKydlicek

---------

Co-authored-by: Erick Friis
---
 libs/core/langchain_core/callbacks/base.py    | 16 +++++++-
 libs/core/langchain_core/callbacks/manager.py |  9 +++++
 .../language_models/chat_models.py            | 28 ++++++++-----
 .../langchain_core/language_models/llms.py    | 25 +++++++++---
 libs/core/tests/unit_tests/fake/callbacks.py  |  8 ++--
 libs/core/tests/unit_tests/fake/chat_model.py | 16 +++++++-
 libs/core/tests/unit_tests/fake/llm.py        | 18 ++++++++-
 .../language_models/chat_models/test_base.py  | 37 +++++++++++++++++
 .../language_models/llms/test_base.py         | 40 ++++++++++++++++++-
 9 files changed, 172 insertions(+), 25 deletions(-)

diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py
index 14078755f4fc5a..ed30e50ff14a48 100644
--- a/libs/core/langchain_core/callbacks/base.py
+++ b/libs/core/langchain_core/callbacks/base.py
@@ -75,7 +75,13 @@ def on_llm_error(
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """
 
 
 class ChainManagerMixin:
@@ -351,7 +357,13 @@ async def on_llm_error(
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """
 
     async def on_chain_start(
         self,
diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index 402900c6321408..b1bb0119279f3b 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -623,6 +623,9 @@ def on_llm_error(
 
         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
         """
         handle_event(
             self.handlers,
@@ -689,6 +692,12 @@ async def on_llm_error(
 
         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+
+
+
         """
         await ahandle_event(
             self.handlers,
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 69bede4a12a501..24bb4114cb4b45 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -223,8 +223,8 @@ def stream(
                 name=config.get("run_name"),
                 batch_size=1,
             )
+            generation: Optional[ChatGenerationChunk] = None
             try:
-                generation: Optional[ChatGenerationChunk] = None
                 for chunk in self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -235,12 +235,15 @@ def stream(
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
-                run_manager.on_llm_end(
-                    LLMResult(generations=[[generation]]),
-                )
+                run_manager.on_llm_end(LLMResult(generations=[[generation]]))
 
     async def astream(
         self,
@@ -277,8 +280,8 @@ async def astream(
                 name=config.get("run_name"),
                 batch_size=1,
             )
+            generation: Optional[ChatGenerationChunk] = None
             try:
-                generation: Optional[ChatGenerationChunk] = None
                 async for chunk in self._astream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -289,7 +292,12 @@ async def astream(
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                await run_manager.on_llm_error(e)
+                await run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
                 await run_manager.on_llm_end(
@@ -366,7 +374,7 @@ def generate(
                 )
             except BaseException as e:
                 if run_managers:
-                    run_managers[i].on_llm_error(e)
+                    run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
                 raise e
         flattened_outputs = [
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
@@ -433,7 +441,9 @@ async def agenerate(
         for i, res in enumerate(results):
             if isinstance(res, BaseException):
                 if run_managers:
-                    await run_managers[i].on_llm_error(res)
+                    await run_managers[i].on_llm_error(
+                        res, response=LLMResult(generations=[])
+                    )
                 exceptions.append(res)
         if exceptions:
             if run_managers:
diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index dea4375b7f4685..e0e830d10be7e3 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -384,8 +384,8 @@ def stream(
                 name=config.get("run_name"),
                 batch_size=1,
             )
+            generation: Optional[GenerationChunk] = None
             try:
-                generation: Optional[GenerationChunk] = None
                 for chunk in self._stream(
                     prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -396,7 +396,12 @@ def stream(
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
                 run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -436,8 +441,8 @@ async def astream(
                 name=config.get("run_name"),
                 batch_size=1,
             )
+            generation: Optional[GenerationChunk] = None
             try:
-                generation: Optional[GenerationChunk] = None
                 async for chunk in self._astream(
                     prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -448,7 +453,12 @@ async def astream(
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                await run_manager.on_llm_error(e)
+                await run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
                 await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -539,7 +549,7 @@ def _generate_helper(
             )
         except BaseException as e:
             for run_manager in run_managers:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(e, response=LLMResult(generations=[]))
             raise e
         flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
@@ -707,7 +717,10 @@ async def _agenerate_helper(
             )
         except BaseException as e:
             await asyncio.gather(
-                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                *[
+                    run_manager.on_llm_error(e, response=LLMResult(generations=[]))
+                    for run_manager in run_managers
+                ]
             )
             raise e
         flattened_outputs = output.flatten()
diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py
index 2a2af92269fe7f..b2bef343fff887 100644
--- a/libs/core/tests/unit_tests/fake/callbacks.py
+++ b/libs/core/tests/unit_tests/fake/callbacks.py
@@ -14,6 +14,7 @@ class BaseFakeCallbackHandler(BaseModel):
     starts: int = 0
     ends: int = 0
     errors: int = 0
+    errors_args: List[Any] = []
     text: int = 0
     ignore_llm_: bool = False
     ignore_chain_: bool = False
@@ -52,8 +53,9 @@ def on_llm_end_common(self) -> None:
         self.llm_ends += 1
         self.ends += 1
 
-    def on_llm_error_common(self) -> None:
+    def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
         self.errors += 1
+        self.errors_args.append({"args": args, "kwargs": kwargs})
 
     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
@@ -160,7 +162,7 @@ def on_llm_error(
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)
 
     def on_retry(
         self,
@@ -322,7 +324,7 @@ async def on_llm_error(
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)
 
     async def on_chain_start(
         self,
diff --git a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py
index e1268ad4fd3dde..717ab02533f379 100644
--- a/libs/core/tests/unit_tests/fake/chat_model.py
+++ b/libs/core/tests/unit_tests/fake/chat_model.py
@@ -45,6 +45,7 @@ class FakeListChatModel(SimpleChatModel):
     responses: List
     sleep: Optional[float] = None
     i: int = 0
+    error_on_chunk_number: Optional[int] = None
 
     @property
     def _llm_type(self) -> str:
@@ -77,9 +78,15 @@ def _stream(
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
+
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
 
     async def _astream(
@@ -94,9 +101,14 @@ async def _astream(
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
 
     @property
diff --git a/libs/core/tests/unit_tests/fake/llm.py b/libs/core/tests/unit_tests/fake/llm.py
index 1ebff8d8ca1b99..165e5b3d2df8de 100644
--- a/libs/core/tests/unit_tests/fake/llm.py
+++ b/libs/core/tests/unit_tests/fake/llm.py
@@ -60,6 +60,8 @@ def _identifying_params(self) -> Mapping[str, Any]:
 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes."""
 
+    error_on_chunk_number: Optional[int] = None
+
     def stream(
         self,
         input: LanguageModelInput,
@@ -69,9 +71,15 @@ def stream(
         **kwargs: Any,
     ) -> Iterator[str]:
         result = self.invoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c
 
     async def astream(
@@ -83,7 +91,13 @@ async def astream(
         **kwargs: Any,
     ) -> AsyncIterator[str]:
         result = await self.ainvoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 0f406a06aef28b..24c49f79a3f490 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -1,8 +1,15 @@
 """Test base chat model."""
+
 import pytest
 
 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel
 
 
@@ -69,3 +76,33 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeListChatModel(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)
diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py
index 37b81a0ed22a91..a6e866cf97627b 100644
--- a/libs/core/tests/unit_tests/language_models/llms/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -1,5 +1,13 @@
+import pytest
+
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
-from tests.unit_tests.fake.llm import FakeListLLM
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
+from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
 
 
 def test_batch() -> None:
@@ -75,3 +83,33 @@ async def test_async_batch_size() -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeStreamingListLLM(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)
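
A minimal usage sketch of the feature this patch adds, not part of the patch itself: a custom handler that reads the new response kwarg passed to on_llm_error. The PartialOutputLogger name and the print-based logging are illustrative assumptions; only the on_llm_error signature and the response kwarg come from the patch.

from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class PartialOutputLogger(BaseCallbackHandler):
    """Collects whatever output was generated before an LLM call failed."""

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # `response` is the kwarg introduced by this patch; treat it as optional
        # since older code paths may still fire the callback without it.
        response: Optional[LLMResult] = kwargs.get("response")
        if response is not None and response.generations:
            print("Partial text before error:", response.generations[0][0].text)

Passed as callbacks=[PartialOutputLogger()] to stream() or astream(), the handler receives the chunks accumulated before the failure, and an empty generations list when the error occurs on the first chunk, which is what the new tests assert.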