# Streaming With Callbacks (Advanced)

In this notebook, we'll see how to combine agents with callbacks to achieve **token-by-token streaming** from the underlying tools!

Our agent will use the OpenAI tools API for tool invocation, and we'll provide the agent with two tools:

1. `where_cat_is_hiding`: A tool that uses an LLM to tell us where the cat is hiding
2. `tell_me_a_joke_about`: A tool that can use an LLM to tell a joke about the given topic


## astream_log

**Token-by-token streaming** can also be implemented entirely with `astream_log`, as shown in the agent streaming notebook.

## Callbacks

In this example, we'll declare a callback that prints tokens to stdout, but only for tokens generated by a model whose tags match the ones we are interested in.

In a production setting, instead of printing the tokens to stdout, you can send them elsewhere (e.g., write them to a streaming response from your FastAPI endpoint).
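To make "send them elsewhere" concrete, here is a minimal, LangChain-free sketch of forwarding matching tokens to an `asyncio.Queue` that a separate consumer drains. `QueueTokenHandler`, `consume`, and `demo` are hypothetical names, and the simulated events stand in for real model output. The `on_llm_new_token` method mirrors the keyword arguments of the handler defined in this notebook; in real code you would subclass `AsyncCallbackHandler` rather than use a plain class.

```python
import asyncio
from typing import Any, List, Optional


class QueueTokenHandler:
    """Hypothetical handler that forwards matching tokens to an asyncio.Queue.

    Its on_llm_new_token mirrors the handler in this notebook; in real code
    you would subclass AsyncCallbackHandler instead of using a plain class.
    """

    def __init__(self, tags_of_interest: List[str]) -> None:
        self.tags_of_interest = set(tags_of_interest)
        self.queue: asyncio.Queue = asyncio.Queue()

    async def on_llm_new_token(
        self, token: str, *, tags: Optional[List[str]] = None, **kwargs: Any
    ) -> None:
        # Forward the token only if the emitting model carries a tag of interest.
        if self.tags_of_interest & set(tags or []):
            await self.queue.put(token)

    async def close(self) -> None:
        # A None sentinel tells the consumer that no more tokens will arrive.
        await self.queue.put(None)


async def consume(queue: asyncio.Queue) -> str:
    # Stand-in for a web framework writing each token to a streaming response.
    parts: List[str] = []
    while (token := await queue.get()) is not None:
        parts.append(token)
    return "".join(parts)


async def demo() -> str:
    handler = QueueTokenHandler(tags_of_interest=["joke"])
    consumer = asyncio.create_task(consume(handler.queue))
    # Simulated callback events: (token, tags of the model that produced it).
    events = [("Thinking...", ["agent"]), ("Why ", ["joke"]), ("cats?", ["joke"])]
    for token, tags in events:
        await handler.on_llm_new_token(token, tags=tags)
    await handler.close()
    return await consumer


print(asyncio.run(demo()))  # prints: Why cats?
```

The queue decouples producing tokens (inside the callback) from delivering them (inside the request handler), so a slow client does not have to block the callback logic directly.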

Please note that in this example, the callbacks are not passed to the LLM's constructor; instead, they are passed as runtime parameters to the `.invoke` or `.stream` methods.

These become *inherited* callbacks that are passed to all sub-dependencies when possible. When declaring tools, the tools must propagate the callbacks to the underlying LLMs for this to work (check out the tools below to see how that works). This is generally good practice, as it also ensures that debug traces are constructed correctly, with the LLM call appearing inside the tool call.

Alternatively, you can pass the callbacks to the constructor, which creates a local (i.e., non-inherited) callback that fires only for that object's runs; in that case you don't have to propagate it.
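The difference between runtime (inherited) and constructor (local) callbacks can be illustrated with a toy model. This is a simplified sketch, not LangChain's implementation: `ToyModel`, `ToyTool`, and the `(run_name, token)` callback shape are all hypothetical. It shows three cases: a tool that forwards its runtime callbacks, a tool that drops them, and a model with a constructor callback.

```python
from typing import Callable, List, Optional, Tuple

# Toy callback: receives (run_name, token). Not LangChain's callback interface.
ToyCallback = Callable[[str, str], None]


class ToyModel:
    """Toy stand-in for a chat model; illustrates callback scoping only."""

    def __init__(self, name: str, callbacks: Optional[List[ToyCallback]] = None) -> None:
        self.name = name
        # Constructor callbacks are local: they fire only for this model's runs.
        self.local_callbacks = list(callbacks or [])

    def stream(self, text: str, inherited: Optional[List[ToyCallback]] = None) -> str:
        # A run notifies its own local callbacks plus whatever was handed down to it.
        handlers = self.local_callbacks + list(inherited or [])
        for token in text.split():
            for cb in handlers:
                cb(self.name, token)
        return text


class ToyTool:
    """Toy tool that either forwards its runtime callbacks to the model or drops them."""

    def __init__(self, model: ToyModel, propagate: bool) -> None:
        self.model = model
        self.propagate = propagate

    def run(self, text: str, callbacks: Optional[List[ToyCallback]] = None) -> str:
        return self.model.stream(text, callbacks if self.propagate else None)


seen: List[Tuple[str, str]] = []


def record(run_name: str, token: str) -> None:
    seen.append((run_name, token))


ToyTool(ToyModel("m1"), propagate=True).run("under the bed", callbacks=[record])
ToyTool(ToyModel("m2"), propagate=False).run("in the box", callbacks=[record])
ToyTool(ToyModel("m3", callbacks=[record]), propagate=False).run("on the shelf")

print(sorted({run_name for run_name, _ in seen}))  # ['m1', 'm3']
```

`m2` never reaches the handler because its tool swallowed the runtime callbacks, which is exactly the failure mode the tools in this notebook avoid by accepting and forwarding `callbacks`.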

In [1]:
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from langchain import agents, hub
from langchain.prompts import ChatPromptTemplate
from langchain.tools import tool
from langchain_core.callbacks import Callbacks
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_openai import ChatOpenAI


# Here is a custom handler that will print the tokens to stdout.
# Instead of printing to stdout you can send the data elsewhere; e.g., to a streaming API response
class TokenByTokenHandler(AsyncCallbackHandler):
    def __init__(self, tags_of_interest: List[str]) -> None:
        """A custom callback handler.

        Args:
            tags_of_interest: Only LLM tokens from models with these tags will be
                              printed.
        """
        self.tags_of_interest = tags_of_interest

    async def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> Any:
        """Run when a chat model starts running."""
        overlap_tags = self.get_overlap_tags(tags)

        if overlap_tags:
            print(",".join(overlap_tags), end=': ', flush=True)

    async def on_llm_end(
            self,
            response: LLMResult,
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> None:
        """Run when LLM ends running."""
        overlap_tags = self.get_overlap_tags(tags)

        if overlap_tags:
            # Who can argue with beauty?
            print()
            print()

    def get_overlap_tags(self, tags: Optional[List[str]]) -> List[str]:
        """Check for overlap with filtered tags."""
        if not tags:
            return []
        return sorted(set(tags or []) & set(self.tags_of_interest or []))

    async def on_llm_new_token(
            self,
            token: str,
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        overlap_tags = self.get_overlap_tags(tags)

        if overlap_tags:
            print(token, end="", flush=True)


## Create the model

**Attention:** For older versions of LangChain, you must set `streaming=True`.

In [2]:
model = ChatOpenAI(temperature=0, streaming=True)

## Tools

We define two tools that rely on a chat model to generate output!

Please note a few things:

1. We invoke the model using `.stream()` to force the output to stream (for older LangChain versions, you still need to set `streaming=True` on the model).
2. We attach tags to the model so that we can filter on those tags in our callback handler.
3. The tools accept callbacks and propagate them to the model as a runtime argument.
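The three points above share one shape: stream inside the tool so callbacks see each token, tag the run, forward the callbacks, and still return one complete string to the agent. Here is a LangChain-free sketch of that shape; `FakeChunk`, `fake_stream`, and `where_cat_is_hiding_sketch` are hypothetical stand-ins, and the hard-coded answer replaces a real model call.

```python
from dataclasses import dataclass
from typing import Callable, Iterator, List, Optional

# Hypothetical callback shape: (token, tags) -> None.
TokenCallback = Callable[[str, List[str]], None]


@dataclass
class FakeChunk:
    content: str  # mirrors the `.content` attribute the real tools read


def fake_stream(
    text: str, tags: List[str], callbacks: Optional[List[TokenCallback]]
) -> Iterator[FakeChunk]:
    """Hypothetical stand-in for model.stream(...): yields one chunk per word."""
    for word in text.split(" "):
        token = word + " "
        for cb in callbacks or []:
            cb(token, tags)  # each callback sees the token as it is produced
        yield FakeChunk(token)


def where_cat_is_hiding_sketch(callbacks: Optional[List[TokenCallback]]) -> str:
    # Stream with a tag, propagate callbacks, then join chunks into one answer.
    chunks = list(fake_stream("Under the bed.", ["hiding_spot"], callbacks))
    return "".join(chunk.content for chunk in chunks).strip()


streamed: List[str] = []


def collect(token: str, tags: List[str]) -> None:
    streamed.append(token)


result = where_cat_is_hiding_sketch([collect])
print(streamed)  # ['Under ', 'the ', 'bed. ']
print(result)    # Under the bed.
```

The callback observes the answer incrementally, while the caller (the agent, in the real notebook) only ever receives the joined string.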

In [3]:
@tool
def where_cat_is_hiding(callbacks: Callbacks) -> str:  # <--- Accept callbacks
    """Where is the cat hiding right now?"""
    chunks = list(
        model.stream(
            "Give one up to three word answer about where the cat might be hiding in the house right now.",
            {"tags": ["hiding_spot"], "callbacks": callbacks}, # <--- Propagate callbacks and assign a tag to this model
        )
    )
    return "".join(chunk.content for chunk in chunks)


@tool
def tell_me_a_joke_about(topic: str, callbacks: Callbacks) -> str: # <--- Accept callbacks
    """Tell a joke about a given topic."""
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are Cat Agent 007. You are funny and know many jokes."),
            ("human", "Tell me a joke about {topic}"),
        ]
    )
    chain = template | model.with_config({"tags": ["joke"]})  # <--- Assign a tag to this model
    chunks = list(chain.stream({"topic": topic}, {"callbacks": callbacks}))  # <--- Propagate callbacks
    return "".join(chunk.content for chunk in chunks)

## Initialize the Agent

In [4]:
# Get the prompt to use - you can modify this!
prompt = hub.pull("hwchase17/openai-tools-agent")
print(prompt.messages)

[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a helpful assistant')), MessagesPlaceholder(variable_name='chat_history', optional=True), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')), MessagesPlaceholder(variable_name='agent_scratchpad')]


In [5]:
tools = [tell_me_a_joke_about, where_cat_is_hiding]
agent = agents.create_openai_tools_agent(
    model.with_config({"tags": ["agent"]}), tools, prompt
)
executor = agents.AgentExecutor(agent=agent, tools=tools)

In [6]:
handler = TokenByTokenHandler(tags_of_interest=["hiding_spot", "joke"])

result = executor.invoke(
    {"input": "where is the cat hiding?"},
    {"callbacks": [handler]},
)

hiding_spot: Under the bed.



In [7]:
handler = TokenByTokenHandler(tags_of_interest=["hiding_spot", "joke"])

result = executor.invoke(
    {"input": "tell me a joke about the location where the cat is hiding"},
    {"callbacks": [handler]},
)

hiding_spot: Under the bed.

joke: Why did the scarecrow bring a ladder under the bed?

Because it heard there was a "bed spring" party going on!

