feat: Add streaming only final aiter of agent (#6274)

#### Add an async iterator that streams only the final output of an agent
This callback handler returns an async iterator and streams only the
final output of the agent, not its intermediate steps.
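
For reference, a minimal usage sketch (assuming a LangChain version contemporary with this commit and an `OPENAI_API_KEY` in the environment; the agent setup and question are illustrative):

```python
import asyncio

from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks.streaming_aiter_final_only import (
    AsyncFinalIteratorCallbackHandler,
)
from langchain.chat_models import ChatOpenAI


async def main() -> None:
    handler = AsyncFinalIteratorCallbackHandler()
    llm = ChatOpenAI(streaming=True, callbacks=[handler], temperature=0)
    tools = load_tools(["llm-math"], llm=llm)
    agent = initialize_agent(
        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
    )

    # Run the agent in the background and consume tokens as they arrive.
    # Intermediate reasoning steps are suppressed; only tokens after the
    # "Final Answer:" prefix are yielded.
    task = asyncio.create_task(agent.arun("What is 2 to the 10th power?"))
    async for token in handler.aiter():
        print(token, end="", flush=True)
    await task


asyncio.run(main())
```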


#### Who can review?

Tag maintainers/contributors who might be interested: @agola11

ninely authored and hinthornw committed Jul 3, 2023
1 parent 6f88bd5 commit 8c92fb3
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions langchain/callbacks/streaming_aiter_final_only.py
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.schema import LLMResult

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]


class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"""Callback handler that returns an async iterator.
Only the final output of the agent will be iterated.
"""

def append_to_last_tokens(self, token: str) -> None:
self.last_tokens.append(token)
self.last_tokens_stripped.append(token.strip())
if len(self.last_tokens) > len(self.answer_prefix_tokens):
self.last_tokens.pop(0)
self.last_tokens_stripped.pop(0)

def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens

def __init__(
self,
*,
answer_prefix_tokens: Optional[List[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
"""Instantiate AsyncFinalIteratorCallbackHandler.
Args:
answer_prefix_tokens: Token sequence that prefixes the answer.
Default is ["Final", "Answer", ":"]
strip_tokens: Ignore white spaces and new lines when comparing
answer_prefix_tokens to last tokens? (to determine if answer has been
reached)
stream_prefix: Should answer prefix itself also be streamed?
"""
super().__init__()
if answer_prefix_tokens is None:
self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
else:
self.answer_prefix_tokens = answer_prefix_tokens
if strip_tokens:
self.answer_prefix_tokens_stripped = [
token.strip() for token in self.answer_prefix_tokens
]
else:
self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
self.last_tokens = [""] * len(self.answer_prefix_tokens)
self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
self.strip_tokens = strip_tokens
self.stream_prefix = stream_prefix
self.answer_reached = False

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
self.answer_reached = False

async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if self.answer_reached:
self.done.set()

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# Remember the last n tokens, where n = len(answer_prefix_tokens)
self.append_to_last_tokens(token)

# Check if the last n tokens match the answer_prefix_tokens list ...
if self.check_if_answer_reached():
self.answer_reached = True
if self.stream_prefix:
for t in self.last_tokens:
self.queue.put_nowait(t)
return

# If yes, then put tokens from now on
if self.answer_reached:
self.queue.put_nowait(token)
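
The sliding-window prefix match above can be exercised without a model. A hedged sketch (the hand-written token split is illustrative; real tokenizers may split "Final Answer:" differently, which is what `strip_tokens` and a custom `answer_prefix_tokens` accommodate):

```python
import asyncio

from langchain.callbacks.streaming_aiter_final_only import (
    AsyncFinalIteratorCallbackHandler,
)


async def demo() -> None:
    handler = AsyncFinalIteratorCallbackHandler(stream_prefix=True)
    await handler.on_llm_start(serialized={}, prompts=["dummy"])

    # Feed a hand-written token stream; only tokens from the
    # "Final Answer:" prefix onward should be queued.
    for tok in ["Thought", ":", " done", "Final", " Answer", ":", " 42"]:
        await handler.on_llm_new_token(tok)

    # Prints 'Final', ' Answer', ':', ' 42' (one per line), because
    # stream_prefix=True also queues the matched prefix tokens.
    while not handler.queue.empty():
        print(repr(handler.queue.get_nowait()))


asyncio.run(demo())
```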
