Added generic event handler for both tokens and function calls (#9263)
# Description

The main motivation for this PR is to sync with JS LangChain:
langchain-ai/langchainjs#2025

Added a `chunk` argument to the `on_llm_new_token` callback so a single handler
works for both tokens and OpenAI function calls in streaming mode.
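As a rough illustration of the intended use (not code from this PR; the handler name and the prints are made up), a custom handler can now read both the plain token and the streamed function-call delta from the chunk:

from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class FunctionCallStreamHandler(BaseCallbackHandler):
    """Illustrative handler: prints tokens and streamed function-call deltas."""

    def on_llm_new_token(
        self,
        token: str,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # Plain content arrives in `token`; OpenAI function-call deltas arrive
        # with an empty token and live in the chat chunk's message metadata.
        if token:
            print(f"token: {token}")
        if isinstance(chunk, ChatGenerationChunk):
            function_call = chunk.message.additional_kwargs.get("function_call")
            if function_call:
                print(f"function call delta: {function_call}")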

Twitter: [@shelfdev](https://twitter.com/ShelfDev)
andrewBatutin committed Aug 25, 2023
1 parent adb2178 commit f771d85
Showing 5 changed files with 119 additions and 6 deletions.
11 changes: 9 additions & 2 deletions libs/langchain/langchain/callbacks/base.py
@@ -10,7 +10,7 @@
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult


class RetrieverManagerMixin:
@@ -43,12 +43,19 @@ class LLMManagerMixin:
    def on_llm_new_token(
        self,
        token: str,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""
        """Run on new LLM token. Only available when streaming is enabled.
        Args:
            token (str): The new token.
            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
                containing content and other information.
        """

    def on_llm_end(
        self,
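For reference, the two chunk types in the new Union differ in what they carry; a small sketch with made-up values (not part of the diff):

from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk, GenerationChunk

# Completion-style LLMs stream GenerationChunk, which only carries text.
text_chunk = GenerationChunk(text="Hello")

# Chat models stream ChatGenerationChunk, whose message can also carry an
# OpenAI function-call delta in additional_kwargs.
chat_chunk = ChatGenerationChunk(
    message=AIMessageChunk(
        content="",
        additional_kwargs={
            "function_call": {"name": "output_formatter", "arguments": "{"}
        },
    )
)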
3 changes: 3 additions & 0 deletions libs/langchain/langchain/callbacks/manager.py
@@ -49,6 +49,7 @@
LLMResult,
)
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk, GenerationChunk

if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
@@ -655,6 +656,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
    async def on_llm_new_token(
        self,
        token: str,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM generates a new token.
@@ -667,6 +669,7 @@ async def on_llm_new_token(
"on_llm_new_token",
"ignore_llm",
token,
chunk,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
10 changes: 8 additions & 2 deletions libs/langchain/langchain/callbacks/tracers/base.py
@@ -13,7 +13,12 @@
from langchain.callbacks.tracers.schemas import Run
from langchain.load.dump import dumpd
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult
from langchain.schema.output import (
    ChatGeneration,
    ChatGenerationChunk,
    GenerationChunk,
    LLMResult,
)

logger = logging.getLogger(__name__)

@@ -122,6 +127,7 @@ def on_llm_start(
    def on_llm_new_token(
        self,
        token: str,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
@@ -139,7 +145,7 @@ def on_llm_new_token(
            {
                "name": "new_token",
                "time": datetime.utcnow(),
                "kwargs": {"token": token},
                "kwargs": {"token": token, "chunk": chunk.dict() if chunk else None},
            },
)

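With this change, a tracer's run events carry the serialized chunk alongside the raw token. A hedged sketch of reading them back (the TokenEventTracer name is hypothetical; _persist_run is the hook a custom tracer implements):

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run


class TokenEventTracer(BaseTracer):
    """Hypothetical tracer that prints the new_token events recorded on a run."""

    def _persist_run(self, run: Run) -> None:
        for event in run.events:
            if event.get("name") == "new_token":
                # kwargs now holds the raw token plus the serialized chunk
                # (or None when the model did not supply a chunk).
                print(event["kwargs"]["token"], event["kwargs"]["chunk"])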
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chat_models/openai.py
@@ -398,7 +398,7 @@ async def _astream(
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content)
                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)

    async def _agenerate(
        self,
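Putting it together, an end-to-end sketch of what the streaming path above enables, reusing the FunctionCallStreamHandler from the description (the person_fn schema is made up, and functions is assumed to be forwarded to the OpenAI API as extra kwargs):

import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage

# Hypothetical function schema for the example.
person_fn = {
    "name": "output_formatter",
    "description": "Record a person's name and age.",
    "parameters": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    },
}


async def main() -> None:
    chat = ChatOpenAI(streaming=True, callbacks=[FunctionCallStreamHandler()])
    # Every streamed delta now reaches on_llm_new_token as a ChatGenerationChunk,
    # so the handler sees function-call arguments as they are generated.
    await chat.agenerate(
        [[HumanMessage(content="Sally is 13 years old")]],
        functions=[person_fn],
    )


asyncio.run(main())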
@@ -1,18 +1,24 @@
"""Test ChatOpenAI wrapper."""

from typing import Any

import pytest

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.chains.openai_functions import (
    create_openai_fn_chain,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from langchain.schema.output import ChatGenerationChunk
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@@ -191,6 +197,97 @@ async def test_async_chat_openai_streaming() -> None:
            assert generation.text == generation.message.content


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_async_chat_openai_streaming_with_function() -> None:
"""Test ChatOpenAI wrapper with multiple completions."""

    class MyCustomAsyncHandler(AsyncCallbackHandler):
        def on_llm_new_token(
            self,
            token: str,
            chunk: ChatGenerationChunk,
            **kwargs: Any,
        ) -> Any:
            print(f"I just got a token: {token}")
            print(f"I just got a chunk: {chunk}")

    json_schema = {
        "title": "Person",
        "description": "Identifying information about a person.",
        "type": "object",
        "properties": {
            "name": {
                "title": "Name",
                "description": "The person's name",
                "type": "string",
            },
            "age": {
                "title": "Age",
                "description": "The person's age",
                "type": "integer",
            },
            "fav_food": {
                "title": "Fav Food",
                "description": "The person's favorite food",
                "type": "string",
            },
        },
        "required": ["name", "age"],
    }

    callback_handler = MyCustomAsyncHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatOpenAI(
        max_tokens=10,
        n=1,
        callback_manager=callback_manager,
        streaming=True,
    )

    prompt_msgs = [
        SystemMessage(
            content="You are a world class algorithm for "
            "extracting information in structured formats."
        ),
        HumanMessage(
            content="Use the given format to extract "
            "information from the following input:"
        ),
        HumanMessagePromptTemplate.from_template("{input}"),
        HumanMessage(content="Tips: Make sure to answer in the correct format"),
    ]
    prompt = ChatPromptTemplate(messages=prompt_msgs)

    function: Any = {
        "name": "output_formatter",
        "description": (
            "Output formatter. Should always be used to format your response to the"
            " user."
        ),
        "parameters": json_schema,
    }
    chain = create_openai_fn_chain(
        [function],
        chat,
        prompt,
        output_parser=None,
    )

    message = HumanMessage(content="Sally is 13 years old")
    response = await chain.agenerate([{"input": message}])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_chat_openai_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.
