anthropic: refactor streaming to use events api; add streaming usage metadata (#22628)

- Refactor streaming to use raw events.
- Add a `stream_usage` class attribute and kwarg to the stream methods that, if True, includes separate chunks in the stream containing usage metadata.

There are two ways to implement streaming with Anthropic's Python SDK, and they differ slightly in how they surface usage metadata.
1. [Use helper
functions](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#streaming-helpers).
This is what we are doing now.
```python
count = 1
with client.messages.stream(**params) as stream:
    for text in stream.text_stream:
        snapshot = stream.current_message_snapshot
        print(f"{count}: {snapshot.usage} -- {text}")
        count = count + 1

final_snapshot = stream.get_final_message()
print(f"{count}: {final_snapshot.usage}")
```
```
1: Usage(input_tokens=8, output_tokens=1) -- Hello
2: Usage(input_tokens=8, output_tokens=1) -- !
3: Usage(input_tokens=8, output_tokens=1) --  How
4: Usage(input_tokens=8, output_tokens=1) --  can
5: Usage(input_tokens=8, output_tokens=1) --  I
6: Usage(input_tokens=8, output_tokens=1) --  assist
7: Usage(input_tokens=8, output_tokens=1) --  you
8: Usage(input_tokens=8, output_tokens=1) --  today
9: Usage(input_tokens=8, output_tokens=1) -- ?
10: Usage(input_tokens=8, output_tokens=12)
```
To do this correctly, we need to emit a new chunk at the end of the
stream containing the usage metadata.
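A rough sketch of what option 1 would require (hypothetical: the helper `_stream_with_trailing_usage` and its signature are illustrative only, reusing `client` and `params` from the snippet above):
```python
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk


def _stream_with_trailing_usage(client, **params):
    # Stream text as usual, then emit one extra blank chunk that carries
    # the usage totals taken from the final message snapshot.
    with client.messages.stream(**params) as stream:
        for text in stream.text_stream:
            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
        final = stream.get_final_message()
        yield ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                usage_metadata={
                    "input_tokens": final.usage.input_tokens,
                    "output_tokens": final.usage.output_tokens,
                    "total_tokens": final.usage.input_tokens
                    + final.usage.output_tokens,
                },
            )
        )
```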

2. [Handle raw
events](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#streaming-responses)
```python
stream = client.messages.create(**params, stream=True)
count = 1
for event in stream:
    print(f"{count}: {event}")
    count = count + 1
```
```
1: RawMessageStartEvent(message=Message(id='msg_01Vdyov2kADZTXqSKkfNJXcS', content=[], model='claude-3-haiku-20240307', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start')
2: RawContentBlockStartEvent(content_block=TextBlock(text='', type='text'), index=0, type='content_block_start')
3: RawContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta')
4: RawContentBlockDeltaEvent(delta=TextDelta(text='!', type='text_delta'), index=0, type='content_block_delta')
5: RawContentBlockDeltaEvent(delta=TextDelta(text=' How', type='text_delta'), index=0, type='content_block_delta')
6: RawContentBlockDeltaEvent(delta=TextDelta(text=' can', type='text_delta'), index=0, type='content_block_delta')
7: RawContentBlockDeltaEvent(delta=TextDelta(text=' I', type='text_delta'), index=0, type='content_block_delta')
8: RawContentBlockDeltaEvent(delta=TextDelta(text=' assist', type='text_delta'), index=0, type='content_block_delta')
9: RawContentBlockDeltaEvent(delta=TextDelta(text=' you', type='text_delta'), index=0, type='content_block_delta')
10: RawContentBlockDeltaEvent(delta=TextDelta(text=' today', type='text_delta'), index=0, type='content_block_delta')
11: RawContentBlockDeltaEvent(delta=TextDelta(text='?', type='text_delta'), index=0, type='content_block_delta')
12: RawContentBlockStopEvent(index=0, type='content_block_stop')
13: RawMessageDeltaEvent(delta=Delta(stop_reason='end_turn', stop_sequence=None), type='message_delta', usage=MessageDeltaUsage(output_tokens=12))
14: RawMessageStopEvent(type='message_stop')
```
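Note that usage metadata is split across two of these events: `message_start` carries `input_tokens` (plus a single placeholder output token) and `message_delta` carries the final `output_tokens`. A minimal sketch of pulling both out, reusing `client` and `params` from above:
```python
input_tokens = output_tokens = 0
for event in client.messages.create(**params, stream=True):
    if event.type == "message_start":
        input_tokens = event.message.usage.input_tokens
    elif event.type == "message_delta":
        output_tokens = event.usage.output_tokens
print(input_tokens, output_tokens)  # e.g. 8 12
```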

Here we implement the second option, in part because it should make
things easier when implementing streaming tool calls in the near future.

This adds two new chunks to the stream, one at the beginning and one at the end, each with blank content and carrying usage metadata. We add kwargs to the stream methods and a class attribute that allow this behavior to be toggled; it is enabled by default (an example of disabling it follows the usage output below). If we merge this, we can add the same kwargs / attribute to the OpenAI integration.

Usage:
```python
from langchain_anthropic import ChatAnthropic

model = ChatAnthropic(
    model="claude-3-haiku-20240307",
    temperature=0
)

full = None
for chunk in model.stream("hi"):
    full = chunk if full is None else full + chunk
    print(chunk)

print(f"\nFull: {full}")
```
```
content='' id='run-8a20843f-25c7-4025-ad72-9add395899e3' usage_metadata={'input_tokens': 8, 'output_tokens': 0, 'total_tokens': 8}
content='Hello' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content='!' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content=' How' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content=' can' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content=' I' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content=' assist' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content=' you' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content=' today' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content='?' id='run-8a20843f-25c7-4025-ad72-9add395899e3'
content='' id='run-8a20843f-25c7-4025-ad72-9add395899e3' usage_metadata={'input_tokens': 0, 'output_tokens': 12, 'total_tokens': 12}

Full: content='Hello! How can I assist you today?' id='run-8a20843f-25c7-4025-ad72-9add395899e3' usage_metadata={'input_tokens': 8, 'output_tokens': 12, 'total_tokens': 20}
```
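The new chunks can be disabled either per call or at construction time (see the `stream_usage` docstring and integration tests in the diff below); a minimal example:
```python
from langchain_anthropic import ChatAnthropic

# Disable usage chunks for every call made with this instance
model = ChatAnthropic(model="claude-3-haiku-20240307", stream_usage=False)

# Or override per call; the kwarg takes precedence over the class attribute
model = ChatAnthropic(model="claude-3-haiku-20240307")
for chunk in model.stream("hi", stream_usage=False):
    assert chunk.usage_metadata is None
```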
ccurme authored and hinthornw committed Jun 20, 2024
1 parent a0f55a3 commit 6fa7244
Showing 2 changed files with 180 additions and 11 deletions.
103 changes: 93 additions & 10 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -43,6 +43,7 @@
ToolCall,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
@@ -446,6 +447,24 @@ class Joke(BaseModel):
{'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}
Message chunks containing token usage will be included during streaming by
default:
.. code-block:: python
stream = llm.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
full.usage_metadata
.. code-block:: python
{'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}
These can be disabled by setting ``stream_usage=False`` in the stream method,
or by setting ``stream_usage=False`` when initializing ChatAnthropic.
Response metadata
.. code-block:: python
@@ -513,6 +532,11 @@ class Config:
streaming: bool = False
"""Whether to use streaming or not."""

stream_usage: bool = True
"""Whether to include usage metadata in streaming output. If True, additional
message chunks will be generated during the stream including usage metadata.
"""

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@@ -636,8 +660,12 @@ def _stream(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
if _tools_in_params(params):
result = self._generate(
@@ -657,25 +685,34 @@ def _stream(
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
stream = self._client.messages.create(**params, stream=True)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event, stream_usage=stream_usage
)
if msg is not None:
chunk = ChatGenerationChunk(message=msg)
if run_manager and isinstance(msg.content, str):
run_manager.on_llm_new_token(msg.content, chunk=chunk)
yield chunk

async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
if _tools_in_params(params):
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
@@ -696,16 +733,21 @@ async def _astream(
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
stream = await self._async_client.messages.create(**params, stream=True)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event, stream_usage=stream_usage
)
if msg is not None:
chunk = ChatGenerationChunk(message=msg)
if run_manager and isinstance(msg.content, str):
await run_manager.on_llm_new_token(msg.content, chunk=chunk)
yield chunk

def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
@@ -1068,6 +1110,47 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
return blocks


def _make_message_chunk_from_anthropic_event(
event: anthropic.types.RawMessageStreamEvent,
*,
stream_usage: bool = True,
) -> Optional[AIMessageChunk]:
"""Convert Anthropic event to AIMessageChunk.
Note that not all events will result in a message chunk. In these cases
we return None.
"""
message_chunk: Optional[AIMessageChunk] = None
if event.type == "message_start" and stream_usage:
input_tokens = event.message.usage.input_tokens
message_chunk = AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=input_tokens,
output_tokens=0,
total_tokens=input_tokens,
),
)
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
elif event.type == "content_block_delta" and event.delta.type == "text_delta":
text = event.delta.text
message_chunk = AIMessageChunk(content=text)
elif event.type == "message_delta" and stream_usage:
output_tokens = event.usage.output_tokens
message_chunk = AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=0,
output_tokens=output_tokens,
total_tokens=output_tokens,
),
)
else:
pass

return message_chunk


@deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass
@@ -1,14 +1,15 @@
"""Test ChatAnthropic chat model."""

import json
from typing import List
from typing import List, Optional

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
@@ -28,16 +29,101 @@ def test_stream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]

full: Optional[BaseMessageChunk] = None
chunks_with_input_token_counts = 0
chunks_with_output_token_counts = 0
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full = token if full is None else full + token
assert isinstance(token, AIMessageChunk)
if token.usage_metadata is not None:
if token.usage_metadata.get("input_tokens"):
chunks_with_input_token_counts += 1
elif token.usage_metadata.get("output_tokens"):
chunks_with_output_token_counts += 1
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with input or output token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
# check token usage is populated
assert isinstance(full, AIMessageChunk)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
assert (
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)


async def test_astream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]

full: Optional[BaseMessageChunk] = None
chunks_with_input_token_counts = 0
chunks_with_output_token_counts = 0
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full = token if full is None else full + token
assert isinstance(token, AIMessageChunk)
if token.usage_metadata is not None:
if token.usage_metadata.get("input_tokens"):
chunks_with_input_token_counts += 1
elif token.usage_metadata.get("output_tokens"):
chunks_with_output_token_counts += 1
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with input or output token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
# check token usage is populated
assert isinstance(full, AIMessageChunk)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
assert (
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)

# test usage metadata can be excluded
model = ChatAnthropic(model_name=MODEL_NAME, stream_usage=False) # type: ignore[call-arg]
async for token in model.astream("hi"):
assert isinstance(token, AIMessageChunk)
assert token.usage_metadata is None
# check we override with kwarg
model = ChatAnthropic(model_name=MODEL_NAME) # type: ignore[call-arg]
assert model.stream_usage
async for token in model.astream("hi", stream_usage=False):
assert isinstance(token, AIMessageChunk)
assert token.usage_metadata is None

# Check expected raw API output
async_client = model._async_client
params: dict = {
"model": "claude-3-haiku-20240307",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}],
"temperature": 0.0,
}
stream = await async_client.messages.create(**params, stream=True)
async for event in stream:
if event.type == "message_start":
assert event.message.usage.input_tokens > 1
# Note: this single output token included in message start event
# does not appear to contribute to overall output token counts. It
# is excluded from the total token count.
assert event.message.usage.output_tokens == 1
elif event.type == "message_delta":
assert event.usage.output_tokens > 1
else:
pass


async def test_abatch() -> None: