59 changes: 53 additions & 6 deletions src/google/adk/models/lite_llm.py
@@ -85,6 +85,8 @@ class UsageMetadataChunk(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cached tokens served from provider cache for the prompt portion
cached_prompt_tokens: int = 0


class LiteLLMClient:
@@ -154,6 +156,46 @@ def _safe_json_serialize(obj) -> str:
return str(obj)


def _extract_cached_prompt_tokens(usage: dict[str, Any]) -> int:
"""Best-effort extraction of cached prompt tokens from LiteLLM usage.

Providers expose cached token metrics in different shapes. Common patterns:
- usage["prompt_tokens_details"]["cached_tokens"] (OpenAI/Azure style)
- usage["prompt_tokens_details"] is a list of dicts with cached_tokens
- usage["cached_prompt_tokens"] (LiteLLM-normalized for some providers)
- usage["cached_tokens"] (flat)

Args:
usage: Usage dictionary from LiteLLM response.

Returns:
Integer number of cached prompt tokens if present; otherwise 0.
"""

try:
details = usage.get("prompt_tokens_details")
if isinstance(details, dict):
value = details.get("cached_tokens")
if isinstance(value, int):
return value
elif isinstance(details, list):
total = 0
for item in details:
if isinstance(item, dict) and isinstance(item.get("cached_tokens"), int):
total += item["cached_tokens"]
if total:
return total

for key in ("cached_prompt_tokens", "cached_tokens"):
value = usage.get(key)
if isinstance(value, int):
return value
except Exception: # noqa: BLE001 - defensive: provider-specific shapes vary
pass

return 0


def _content_to_message_param(
content: types.Content,
) -> Union[Message, list[Message]]:
@@ -463,10 +505,12 @@ def _model_response_to_chunk(
# finish_reason set. But this is not the case we are observing from litellm.
# So we are sending it as a separate chunk to be set on the llm_response.
if response.get("usage", None):
usage_dict = response["usage"]
yield UsageMetadataChunk(
prompt_tokens=response["usage"].get("prompt_tokens", 0),
completion_tokens=response["usage"].get("completion_tokens", 0),
total_tokens=response["usage"].get("total_tokens", 0),
prompt_tokens=usage_dict.get("prompt_tokens", 0),
completion_tokens=usage_dict.get("completion_tokens", 0),
total_tokens=usage_dict.get("total_tokens", 0),
cached_prompt_tokens=_extract_cached_prompt_tokens(usage_dict),
), None
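
Illustration (not part of the diff): an example of the chunk yielded above, assuming a final streaming response whose usage dict reports cached tokens.

# Illustrative only: given usage {"prompt_tokens": 2100, "completion_tokens": 50,
#   "total_tokens": 2150, "prompt_tokens_details": {"cached_tokens": 1800}},
# this yields UsageMetadataChunk(prompt_tokens=2100, completion_tokens=50,
#   total_tokens=2150, cached_prompt_tokens=1800), which generate_content_async
# later surfaces as cached_content_token_count.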


@@ -491,10 +535,12 @@ def _model_response_to_generate_content_response(

llm_response = _message_to_generate_content_response(message)
if response.get("usage", None):
usage_dict = response["usage"]
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
candidates_token_count=response["usage"].get("completion_tokens", 0),
total_token_count=response["usage"].get("total_tokens", 0),
prompt_token_count=usage_dict.get("prompt_tokens", 0),
candidates_token_count=usage_dict.get("completion_tokens", 0),
total_token_count=usage_dict.get("total_tokens", 0),
cached_content_token_count=_extract_cached_prompt_tokens(usage_dict),
)
return llm_response
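
Illustration (not part of the diff): how a LiteLLM usage dict maps onto the Gen AI usage metadata, using the same field mapping as the code above and a hypothetical payload.

# --- Illustrative only; hypothetical usage payload ---
usage = {
    "prompt_tokens": 2100,
    "completion_tokens": 50,
    "total_tokens": 2150,
    "prompt_tokens_details": {"cached_tokens": 1800},
}
metadata = types.GenerateContentResponseUsageMetadata(
    prompt_token_count=usage.get("prompt_tokens", 0),          # 2100
    candidates_token_count=usage.get("completion_tokens", 0),  # 50
    total_token_count=usage.get("total_tokens", 0),            # 2150
    cached_content_token_count=_extract_cached_prompt_tokens(usage),  # 1800
)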

@@ -874,6 +920,7 @@ async def generate_content_async(
prompt_token_count=chunk.prompt_tokens,
candidates_token_count=chunk.completion_tokens,
total_token_count=chunk.total_tokens,
cached_content_token_count=chunk.cached_prompt_tokens,
)

if (
82 changes: 82 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -1317,6 +1317,88 @@ def test_model_response_to_chunk(response, expected_chunks, expected_finished):
assert usage_chunk.total_tokens == expected_chunks[1].total_tokens


@pytest.mark.asyncio
async def test_generate_content_async_with_cached_tokens_non_stream(
lite_llm_instance, mock_acompletion
):
# Simulate LiteLLM usage shapes that include cached tokens
mock_response_with_cached_usage = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
)
)
],
usage={
"prompt_tokens": 2100,
"completion_tokens": 50,
"total_tokens": 2150,
# Common provider shapes
"prompt_tokens_details": {"cached_tokens": 1800},
},
)
mock_acompletion.return_value = mock_response_with_cached_usage

llm_request = LlmRequest(
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="q")])
]
)

results = [
r async for r in lite_llm_instance.generate_content_async(llm_request)
]
assert len(results) == 1
resp = results[0]
assert resp.usage_metadata is not None
assert resp.usage_metadata.prompt_token_count == 2100
assert resp.usage_metadata.candidates_token_count == 50
assert resp.usage_metadata.total_token_count == 2150
# Key assertion: cached_content_token_count is propagated
assert resp.usage_metadata.cached_content_token_count == 1800


@pytest.mark.asyncio
async def test_generate_content_async_with_cached_tokens_stream(
mock_completion, lite_llm_instance
):
# Build a stream with final usage chunk that includes cached tokens
streaming_with_cached_usage = [
*STREAMING_MODEL_RESPONSE,
ModelResponse(
usage={
"prompt_tokens": 2100,
"completion_tokens": 50,
"total_tokens": 2150,
# Alternative flattened shape
"cached_prompt_tokens": 1700,
},
choices=[StreamingChoices(finish_reason=None)],
),
]
mock_completion.return_value = iter(streaming_with_cached_usage)

llm_request = LlmRequest(
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="q")])
]
)
responses = [
r async for r in lite_llm_instance.generate_content_async(
llm_request, stream=True
)
]
# Final aggregated response carries usage
assert len(responses) == 4
final_resp = responses[-1]
assert final_resp.usage_metadata is not None
assert final_resp.usage_metadata.prompt_token_count == 2100
assert final_resp.usage_metadata.total_token_count == 2150
assert final_resp.usage_metadata.cached_content_token_count == 1700


@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):
lite_llm_instance = LiteLlm(