44 changes: 40 additions & 4 deletions src/google/adk/models/anthropic_llm.py
@@ -77,10 +77,14 @@ def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
def to_google_genai_finish_reason(
anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
if anthropic_stop_reason in [
"end_turn", "stop_sequence", "tool_use", "pause_turn"
]:
return "STOP"
if anthropic_stop_reason == "max_tokens":
return "MAX_TOKENS"
if anthropic_stop_reason == "refusal":
return "SAFETY"
return "FINISH_REASON_UNSPECIFIED"


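Taken together, the mapping is a small lookup with a default. A standalone sketch of the behavior (the function and table names here are hypothetical, not part of the diff):

```python
from typing import Optional

# Sketch of the stop_reason -> finish_reason mapping in the hunk above:
# "pause_turn" joins the STOP bucket, "refusal" maps to SAFETY, and anything
# unrecognized (including None) falls back to FINISH_REASON_UNSPECIFIED.
_FINISH_REASONS = {
    "end_turn": "STOP",
    "stop_sequence": "STOP",
    "tool_use": "STOP",
    "pause_turn": "STOP",
    "max_tokens": "MAX_TOKENS",
    "refusal": "SAFETY",
}


def map_stop_reason(stop_reason: Optional[str]) -> str:
    return _FINISH_REASONS.get(stop_reason, "FINISH_REASON_UNSPECIFIED")


assert map_stop_reason("pause_turn") == "STOP"
assert map_stop_reason(None) == "FINISH_REASON_UNSPECIFIED"
```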
@@ -253,8 +257,7 @@ def message_to_generate_content_response(
message.usage.input_tokens + message.usage.output_tokens
),
),
- # TODO: Deal with these later.
- # finish_reason=to_google_genai_finish_reason(message.stop_reason),
+ finish_reason=to_google_genai_finish_reason(message.stop_reason),
)


@@ -402,6 +405,22 @@ async def generate_content_async(
else NOT_GIVEN
)

config = llm_request.config
temperature = (
NOT_GIVEN if config is None or config.temperature is None
else config.temperature
)
top_p = (
NOT_GIVEN if config is None or config.top_p is None else config.top_p
)
top_k = (
NOT_GIVEN if config is None or config.top_k is None else config.top_k
)
stop_sequences = (
NOT_GIVEN if not (config and config.stop_sequences)
else config.stop_sequences
)

if not stream:
message = await self._anthropic_client.messages.create(
model=model_to_use,
@@ -410,11 +429,17 @@
tools=tools,
tool_choice=tool_choice,
max_tokens=self.max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
)
yield message_to_generate_content_response(message)
else:
async for response in self._generate_content_streaming(
- llm_request, messages, tools, tool_choice
+ llm_request, messages, tools, tool_choice,
+ temperature=temperature, top_p=top_p,
+ top_k=top_k, stop_sequences=stop_sequences,
):
yield response

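All four new parameters follow the same omit-when-unset pattern: the Anthropic SDK's NOT_GIVEN sentinel keeps a field out of the request entirely, which is not the same as sending an explicit null. A minimal sketch of the idea (the helper name and config shape are illustrative, not the diff's code):

```python
from anthropic import NOT_GIVEN


def _or_not_given(value):
    # None (and, for stop_sequences in the diff, an empty list as well)
    # means "omit this field from the API request", not "send null".
    return NOT_GIVEN if value is None else value


# e.g., with config possibly None:
# temperature = _or_not_given(config.temperature if config else None)
```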
@@ -424,6 +449,10 @@ async def _generate_content_streaming(
messages: list[anthropic_types.MessageParam],
tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven],
tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven],
temperature: Union[float, NotGiven] = NOT_GIVEN,
top_p: Union[float, NotGiven] = NOT_GIVEN,
top_k: Union[int, NotGiven] = NOT_GIVEN,
stop_sequences: Union[list[str], NotGiven] = NOT_GIVEN,
) -> AsyncGenerator[LlmResponse, None]:
"""Handles streaming responses from Anthropic models.

@@ -439,6 +468,10 @@
tool_choice=tool_choice,
max_tokens=self.max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
)

# Track content blocks being built during streaming.
@@ -447,6 +480,7 @@
tool_use_blocks: dict[int, _ToolUseAccumulator] = {}
input_tokens = 0
output_tokens = 0
stop_reason: Optional[str] = None

async for event in raw_stream:
if event.type == "message_start":
@@ -482,6 +516,7 @@

elif event.type == "message_delta":
output_tokens = event.usage.output_tokens
stop_reason = event.delta.stop_reason

# Build the final aggregated response with all content.
all_parts: list[types.Part] = []
@@ -505,6 +540,7 @@
candidates_token_count=output_tokens,
total_token_count=input_tokens + output_tokens,
),
finish_reason=to_google_genai_finish_reason(stop_reason),
partial=False,
)

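In the streaming path, stop_reason arrives on the final message_delta event rather than on a Message object, so the loop above carries it forward into the aggregated response. A self-contained sketch of that accumulation (the event shape mirrors the SDK's stream events; the function is hypothetical):

```python
from typing import Iterable, Optional


def final_stop_reason(events: Iterable) -> Optional[str]:
    # Only message_delta events carry stop_reason (and final usage);
    # content_block_* events stream text and tool-use deltas instead.
    stop_reason = None
    for event in events:
        if event.type == "message_delta":
            stop_reason = event.delta.stop_reason
    return stop_reason
```

The aggregated response then maps this value through to_google_genai_finish_reason, exactly as the non-streaming path does.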
271 changes: 271 additions & 0 deletions tests/unittests/models/test_anthropic_llm.py
@@ -1350,3 +1350,274 @@ async def test_non_streaming_does_not_pass_stream_param():
mock_client.messages.create.assert_called_once()
_, kwargs = mock_client.messages.create.call_args
assert "stream" not in kwargs


# --- Tests for finish_reason population ---


def test_to_google_genai_finish_reason_mappings():
"""to_google_genai_finish_reason maps all known stop_reason values."""
from google.adk.models.anthropic_llm import to_google_genai_finish_reason

assert to_google_genai_finish_reason("end_turn") == "STOP"
assert to_google_genai_finish_reason("stop_sequence") == "STOP"
assert to_google_genai_finish_reason("tool_use") == "STOP"
assert to_google_genai_finish_reason("pause_turn") == "STOP"
assert to_google_genai_finish_reason("max_tokens") == "MAX_TOKENS"
assert to_google_genai_finish_reason("refusal") == "SAFETY"
assert to_google_genai_finish_reason(None) == "FINISH_REASON_UNSPECIFIED"
assert to_google_genai_finish_reason("unknown_value") == "FINISH_REASON_UNSPECIFIED"


def test_message_to_generate_content_response_sets_finish_reason():
"""message_to_generate_content_response populates finish_reason."""
from google.adk.models.anthropic_llm import message_to_generate_content_response

message = anthropic_types.Message(
id="msg_test",
content=[anthropic_types.TextBlock(text="Hi!", type="text", citations=None)],
model="claude-sonnet-4-20250514",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
input_tokens=5,
output_tokens=2,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
server_tool_use=None,
service_tier=None,
),
)
response = message_to_generate_content_response(message)
assert response.finish_reason == "STOP"


def test_message_to_generate_content_response_finish_reason_max_tokens():
"""message_to_generate_content_response maps max_tokens correctly."""
from google.adk.models.anthropic_llm import message_to_generate_content_response

message = anthropic_types.Message(
id="msg_test",
content=[anthropic_types.TextBlock(text="...", type="text", citations=None)],
model="claude-sonnet-4-20250514",
role="assistant",
stop_reason="max_tokens",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
input_tokens=5,
output_tokens=8192,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
server_tool_use=None,
service_tier=None,
),
)
response = message_to_generate_content_response(message)
assert response.finish_reason == "MAX_TOKENS"


@pytest.mark.asyncio
async def test_streaming_final_response_has_finish_reason():
"""Streaming final response should include finish_reason from stop_reason."""
llm = AnthropicLlm(model="claude-sonnet-4-20250514")

events = [
MagicMock(
type="message_start",
message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
),
MagicMock(
type="content_block_start",
index=0,
content_block=anthropic_types.TextBlock(text="", type="text"),
),
MagicMock(
type="content_block_delta",
index=0,
delta=anthropic_types.TextDelta(text="Done.", type="text_delta"),
),
MagicMock(type="content_block_stop", index=0),
MagicMock(
type="message_delta",
delta=MagicMock(stop_reason="max_tokens"),
usage=MagicMock(output_tokens=3),
),
MagicMock(type="message_stop"),
]

mock_client = MagicMock()
mock_client.messages.create = AsyncMock(
return_value=_make_mock_stream_events(events)
)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(system_instruction="Test"),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
responses = [
r async for r in llm.generate_content_async(llm_request, stream=True)
]

final = responses[-1]
assert final.partial is False
assert final.finish_reason == "MAX_TOKENS"


# --- Tests for sampling parameter forwarding ---


@pytest.mark.asyncio
async def test_non_streaming_forwards_sampling_params():
"""Non-streaming generate_content_async forwards temperature/top_p/top_k/stop_sequences."""
llm = AnthropicLlm(model="claude-sonnet-4-20250514")

mock_message = anthropic_types.Message(
id="msg_test",
content=[
anthropic_types.TextBlock(text="Hello!", type="text", citations=None)
],
model="claude-sonnet-4-20250514",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
input_tokens=5,
output_tokens=2,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
server_tool_use=None,
service_tier=None,
),
)

mock_client = MagicMock()
mock_client.messages.create = AsyncMock(return_value=mock_message)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["END", "STOP"],
system_instruction="Be helpful",
),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
_ = [r async for r in llm.generate_content_async(llm_request, stream=False)]

mock_client.messages.create.assert_called_once()
_, kwargs = mock_client.messages.create.call_args
assert kwargs["temperature"] == 0.7
assert kwargs["top_p"] == 0.9
assert kwargs["top_k"] == 40
assert kwargs["stop_sequences"] == ["END", "STOP"]


@pytest.mark.asyncio
async def test_streaming_forwards_sampling_params():
"""Streaming generate_content_async forwards temperature/top_p/top_k/stop_sequences."""
llm = AnthropicLlm(model="claude-sonnet-4-20250514")

events = [
MagicMock(
type="message_start",
message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
),
MagicMock(type="content_block_start", index=0,
content_block=anthropic_types.TextBlock(text="", type="text")),
MagicMock(
type="content_block_delta", index=0,
delta=anthropic_types.TextDelta(text="Hi", type="text_delta"),
),
MagicMock(type="content_block_stop", index=0),
MagicMock(
type="message_delta",
delta=MagicMock(stop_reason="end_turn"),
usage=MagicMock(output_tokens=1),
),
MagicMock(type="message_stop"),
]

mock_client = MagicMock()
mock_client.messages.create = AsyncMock(
return_value=_make_mock_stream_events(events)
)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(
temperature=0.3,
top_p=0.95,
top_k=10,
stop_sequences=["<|END|>"],
system_instruction="Test",
),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
_ = [r async for r in llm.generate_content_async(llm_request, stream=True)]

mock_client.messages.create.assert_called_once()
_, kwargs = mock_client.messages.create.call_args
assert kwargs["temperature"] == 0.3
assert kwargs["top_p"] == 0.95
assert kwargs["top_k"] == 10
assert kwargs["stop_sequences"] == ["<|END|>"]
assert kwargs["stream"] is True


@pytest.mark.asyncio
async def test_sampling_params_use_not_given_when_absent():
"""Sampling params default to NOT_GIVEN when absent from config."""
from anthropic import NOT_GIVEN

llm = AnthropicLlm(model="claude-sonnet-4-20250514")

mock_message = anthropic_types.Message(
id="msg_test",
content=[
anthropic_types.TextBlock(text="Hi!", type="text", citations=None)
],
model="claude-sonnet-4-20250514",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
input_tokens=5,
output_tokens=2,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
server_tool_use=None,
service_tier=None,
),
)

mock_client = MagicMock()
mock_client.messages.create = AsyncMock(return_value=mock_message)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(system_instruction="Test"),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
_ = [r async for r in llm.generate_content_async(llm_request, stream=False)]

_, kwargs = mock_client.messages.create.call_args
assert kwargs["temperature"] is NOT_GIVEN
assert kwargs["top_p"] is NOT_GIVEN
assert kwargs["top_k"] is NOT_GIVEN
assert kwargs["stop_sequences"] is NOT_GIVEN