Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions py/src/braintrust/integrations/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,24 @@ def _get_model_name_from_response(response: LLMResult) -> str | None:
return model_name


def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool:
# LangChain provider packages use different cache-token conventions:
# - OpenAI-style responses report cache reads as a subset of input_tokens.
# - Anthropic-style responses report cache reads/creation separately from input_tokens.
#
# Avoid provider-name checks here so any LangChain integration using the same
# "separate cache tokens" schema gets normalized, while providers that only
# expose cache_read as input-token detail do not get double-counted.
return any(
key in input_token_details
for key in (
"cache_creation",
"ephemeral_5m_input_tokens",
"ephemeral_1h_input_tokens",
)
)


def _get_metrics_from_response(response: LLMResult):
metrics = {}

Expand Down Expand Up @@ -646,10 +664,14 @@ def _get_metrics_from_response(response: LLMResult):
# langchain-anthropic >= 1.4.0 maps cache_creation_input_tokens to
# ephemeral tier fields (ephemeral_5m_input_tokens, ephemeral_1h_input_tokens)
# rather than the top-level cache_creation field. Sum both for compat.
cache_creation = input_token_details.get("cache_creation") or (
input_token_details.get("ephemeral_5m_input_tokens", 0)
+ input_token_details.get("ephemeral_1h_input_tokens", 0)
)
cache_creation = input_token_details.get("cache_creation")
if not cache_creation and (
"ephemeral_5m_input_tokens" in input_token_details
or "ephemeral_1h_input_tokens" in input_token_details
):
cache_creation = input_token_details.get("ephemeral_5m_input_tokens", 0) + input_token_details.get(
"ephemeral_1h_input_tokens", 0
)

if cache_read is not None:
metrics["prompt_cached_tokens"] = cache_read
Expand All @@ -665,6 +687,7 @@ def _get_metrics_from_response(response: LLMResult):
and prompt_tokens is not None
and completion_tokens is not None
and total_tokens == prompt_tokens + completion_tokens
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
):
metrics["prompt_tokens"] = prompt_tokens + cache_tokens
metrics["total_tokens"] = total_tokens + cache_tokens
Expand Down
30 changes: 30 additions & 0 deletions py/src/braintrust/integrations/langchain/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
from braintrust import logger
from braintrust.integrations.langchain import BraintrustCallbackHandler
from braintrust.integrations.langchain.callbacks import _get_metrics_from_response
from braintrust.logger import flush
from braintrust.test_helpers import init_test_logger
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnableMap, RunnableSerializable
Expand Down Expand Up @@ -906,6 +908,34 @@ def test_streaming_ttft(logger_memory_logger):
)


def test_openai_cached_tokens_are_not_folded_into_prompt_tokens():
response = LLMResult(
generations=[
[
ChatGeneration(
message=AIMessage(
content="Done",
response_metadata={"model_name": "gpt-4o-mini-2024-07-18"},
usage_metadata={
"input_tokens": 1000,
"output_tokens": 200,
"total_tokens": 1200,
"input_token_details": {"cache_read": 500},
},
)
)
]
]
)

assert _get_metrics_from_response(response) == {
"prompt_tokens": 1000,
"completion_tokens": 200,
"total_tokens": 1200,
"prompt_cached_tokens": 500,
}


@pytest.mark.vcr
def test_prompt_caching_tokens(logger_memory_logger):
from langchain_anthropic import ChatAnthropic
Expand Down
Loading