Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callbacks Refactor [base] #3256

Merged
merged 45 commits into from
Apr 30, 2023
Merged
Changes from 3 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
3cc2ce6
callbacks changes
agola11 Apr 20, 2023
55c7964
Merge branch 'master' into ankush/callbacks-refactor
agola11 Apr 21, 2023
fa4a4f2
cr
agola11 Apr 21, 2023
675e27c
Callbacks Refactor [2/n]: refactor `CallbackManager` code to own file…
agola11 Apr 23, 2023
90cef7b
cr
agola11 Apr 23, 2023
4cdd19b
Callbacks Refactor [2/n] update tracer to work with new callbacks mec…
agola11 Apr 26, 2023
7bcdc66
fix notebook and warnings
agola11 Apr 26, 2023
6fec15b
write to different session
agola11 Apr 26, 2023
5066869
fix execution order issue
agola11 Apr 27, 2023
e953d2c
mypy
agola11 Apr 27, 2023
6cd653d
cr
agola11 Apr 27, 2023
8ae809a
mypy
agola11 Apr 28, 2023
1fc3941
mypy
agola11 Apr 28, 2023
15c0fa5
cr
agola11 Apr 28, 2023
5dcb44e
fix llm chain
agola11 Apr 28, 2023
da27d87
fix most tests
agola11 Apr 28, 2023
2ed4649
fix baby agi
agola11 Apr 28, 2023
0e81e83
Nc/callbacks docs (#3717)
nfcampos Apr 28, 2023
eb9de30
merge
agola11 Apr 28, 2023
1b48ea8
cr
agola11 Apr 28, 2023
18138c6
cr
agola11 Apr 28, 2023
50f6895
Chains callbacks refactor (#3683)
dev2049 Apr 28, 2023
eeb18c4
Merge branch 'master' of github.com:hwchase17/langchain into ankush/c…
agola11 Apr 28, 2023
40f3f6e
Merge branch 'ankush/callbacks-refactor' of github.com:hwchase17/lang…
agola11 Apr 28, 2023
83cda5e
lint
agola11 Apr 28, 2023
9c876bd
update chain notebooks (#3740)
dev2049 Apr 28, 2023
43410e4
fix test
agola11 Apr 28, 2023
145e1af
Merge branch 'ankush/callbacks-refactor' of github.com:hwchase17/lang…
agola11 Apr 28, 2023
56f16cd
Merge branch 'master' into ankush/callbacks-refactor
agola11 Apr 28, 2023
9c988ae
cr
agola11 Apr 28, 2023
bd9ac67
nb nit (#3744)
dev2049 Apr 28, 2023
e60489e
fix lint
agola11 Apr 28, 2023
3c5f983
Merge branch 'ankush/callbacks-refactor' of github.com:hwchase17/lang…
agola11 Apr 28, 2023
9dad051
fix test warnings (#3753)
dev2049 Apr 29, 2023
5f78219
fix some docs, add session variable
agola11 Apr 29, 2023
290fe75
Add RunManager to Tools Arguments (#3746)
vowelparrot Apr 29, 2023
20ba888
Call Manager for New Tools (#3755)
vowelparrot Apr 29, 2023
a038f37
Resolve merge conflicts
vowelparrot Apr 29, 2023
9192abc
Notebook Nits
vowelparrot Apr 29, 2023
35cc38f
merge
agola11 Apr 29, 2023
ebc6242
fix docs
agola11 Apr 29, 2023
737467a
use UUID
agola11 Apr 29, 2023
3839703
bw compat environ variable
agola11 Apr 29, 2023
fa1742c
fix openai callback
agola11 Apr 29, 2023
fb78f69
cr
hwchase17 Apr 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
104 changes: 37 additions & 67 deletions langchain/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain.schema import AgentAction, AgentFinish, LLMResult


class BaseCallbackHandler(ABC):
class BaseCallbackHandler:
"""Base callback handler that can be used to handle callbacks from langchain."""

@property
Expand All @@ -30,67 +30,54 @@ def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False

@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""

@abstractmethod
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""

@abstractmethod
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
"""Run when LLM ends running."""

@abstractmethod
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when LLM errors."""

@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
"""Run when chain starts running."""

@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""

@abstractmethod
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when chain errors."""

@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
"""Run when tool starts running."""

@abstractmethod
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""

@abstractmethod
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when tool errors."""

@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""

@abstractmethod
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""

@abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""

Expand Down Expand Up @@ -127,6 +114,21 @@ def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers

def _handle_event(
self,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
for handler in self.handlers:
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
if verbose or handler.always_verbose:
getattr(handler, event_name)(*args, **kwargs)

def on_llm_start(
self,
serialized: Dict[str, Any],
Expand All @@ -135,28 +137,21 @@ def on_llm_start(
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_start(serialized, prompts, **kwargs)
self._handle_event(
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
)

def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM generates a new token."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_new_token(token, **kwargs)
self._handle_event("on_llm_new_token", "ignore_llm", verbose, token, **kwargs)

def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_end(response)
self._handle_event("on_llm_end", "ignore_llm", verbose, response, **kwargs)

def on_llm_error(
self,
Expand All @@ -165,10 +160,7 @@ def on_llm_error(
**kwargs: Any
) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_error(error)
self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)

def on_chain_start(
self,
Expand All @@ -178,19 +170,15 @@ def on_chain_start(
**kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_start(serialized, inputs, **kwargs)
self._handle_event(
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
)

def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_end(outputs)
self._handle_event("on_chain_end", "ignore_chain", verbose, outputs, **kwargs)

def on_chain_error(
self,
Expand All @@ -199,10 +187,7 @@ def on_chain_error(
**kwargs: Any
) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_error(error)
self._handle_event("on_chain_error", "ignore_chain", verbose, error, **kwargs)

def on_tool_start(
self,
Expand All @@ -212,26 +197,19 @@ def on_tool_start(
**kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_start(serialized, input_str, **kwargs)
self._handle_event(
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
)

def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_action(action, **kwargs)
self._handle_event("on_agent_action", "ignore_agent", verbose, action, **kwargs)

def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_end(output, **kwargs)
self._handle_event("on_tool_end", "ignore_agent", verbose, output, **kwargs)

def on_tool_error(
self,
Expand All @@ -240,25 +218,17 @@ def on_tool_error(
**kwargs: Any
) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_error(error)
self._handle_event("on_tool_error", "ignore_agent", verbose, error, **kwargs)

def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run on additional input from chains and agents."""
for handler in self.handlers:
if verbose or handler.always_verbose:
handler.on_text(text, **kwargs)
self._handle_event("on_text", None, verbose, text, **kwargs)

def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent end."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_finish(finish, **kwargs)
self._handle_event("on_agent_finish", "ignore_agent", verbose, finish, **kwargs)

def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
Expand Down Expand Up @@ -328,7 +298,7 @@ async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""


async def _handle_event_for_handler(
async def _ahandle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
Expand Down Expand Up @@ -370,7 +340,7 @@ async def _handle_event(
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_handle_event_for_handler(
_ahandle_event_for_handler(
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
)
for handler in self.handlers
Expand Down