From 351a8e68ca3ead4cb0e2e150b0504cce9431fda7 Mon Sep 17 00:00:00 2001
From: supermario_leo
Date: Wed, 22 Apr 2026 03:28:42 +0800
Subject: [PATCH] fix(models): forward sampling params and populate
 finish_reason in AnthropicLlm

- Forward temperature, top_p, top_k, stop_sequences from LlmRequest.config
  to both the non-streaming and streaming Anthropic messages.create calls.
  Previously these parameters were silently ignored, making it impossible
  to control generation from the ADK config interface.
- Populate finish_reason on LlmResponse for both the non-streaming path
  (uncomment the existing helper call) and the streaming path (capture
  stop_reason from the message_delta event).
- Extend to_google_genai_finish_reason() with the two missing mappings:
  pause_turn -> STOP (the API paused a long-running turn) and
  refusal -> SAFETY (the model declined the request).

Fixes #5393, #5394
---
 src/google/adk/models/anthropic_llm.py       |  44 ++-
 tests/unittests/models/test_anthropic_llm.py | 271 +++++++++++++++++++
 2 files changed, 311 insertions(+), 4 deletions(-)

diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index a14c767f23..38db97af02 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -77,10 +77,14 @@ def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
 def to_google_genai_finish_reason(
     anthropic_stop_reason: Optional[str],
 ) -> types.FinishReason:
-  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
+  if anthropic_stop_reason in [
+      "end_turn", "stop_sequence", "tool_use", "pause_turn"
+  ]:
     return "STOP"
   if anthropic_stop_reason == "max_tokens":
     return "MAX_TOKENS"
+  if anthropic_stop_reason == "refusal":
+    return "SAFETY"
   return "FINISH_REASON_UNSPECIFIED"
 
 
@@ -253,8 +257,7 @@ def message_to_generate_content_response(
               message.usage.input_tokens + message.usage.output_tokens
           ),
       ),
-      # TODO: Deal with these later.
-      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
+      finish_reason=to_google_genai_finish_reason(message.stop_reason),
   )
 
 
@@ -402,6 +405,22 @@ async def generate_content_async(
         else NOT_GIVEN
     )
 
+    config = llm_request.config
+    temperature = (
+        NOT_GIVEN if config is None or config.temperature is None
+        else config.temperature
+    )
+    top_p = (
+        NOT_GIVEN if config is None or config.top_p is None else config.top_p
+    )
+    top_k = (
+        NOT_GIVEN if config is None or config.top_k is None else config.top_k
+    )
+    stop_sequences = (
+        NOT_GIVEN if not (config and config.stop_sequences)
+        else config.stop_sequences
+    )
+
     if not stream:
       message = await self._anthropic_client.messages.create(
           model=model_to_use,
@@ -410,11 +429,17 @@ async def generate_content_async(
           messages=messages,
           system=system,
           tools=tools,
          tool_choice=tool_choice,
           max_tokens=self.max_tokens,
+          temperature=temperature,
+          top_p=top_p,
+          top_k=top_k,
+          stop_sequences=stop_sequences,
       )
       yield message_to_generate_content_response(message)
     else:
       async for response in self._generate_content_streaming(
-          llm_request, messages, tools, tool_choice
+          llm_request, messages, tools, tool_choice,
+          temperature=temperature, top_p=top_p,
+          top_k=top_k, stop_sequences=stop_sequences,
       ):
         yield response
 
@@ -424,6 +449,10 @@ async def _generate_content_streaming(
       messages: list[anthropic_types.MessageParam],
       tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven],
       tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven],
+      temperature: Union[float, NotGiven] = NOT_GIVEN,
+      top_p: Union[float, NotGiven] = NOT_GIVEN,
+      top_k: Union[int, NotGiven] = NOT_GIVEN,
+      stop_sequences: Union[list[str], NotGiven] = NOT_GIVEN,
   ) -> AsyncGenerator[LlmResponse, None]:
     """Handles streaming responses from Anthropic models.
 
@@ -439,6 +468,10 @@ async def _generate_content_streaming(
         tool_choice=tool_choice,
         max_tokens=self.max_tokens,
         stream=True,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        stop_sequences=stop_sequences,
     )
 
     # Track content blocks being built during streaming.
@@ -447,6 +480,7 @@ async def _generate_content_streaming(
     tool_use_blocks: dict[int, _ToolUseAccumulator] = {}
     input_tokens = 0
     output_tokens = 0
+    stop_reason: Optional[str] = None
 
     async for event in raw_stream:
       if event.type == "message_start":
@@ -482,6 +516,7 @@ async def _generate_content_streaming(
 
       elif event.type == "message_delta":
         output_tokens = event.usage.output_tokens
+        stop_reason = event.delta.stop_reason
 
     # Build the final aggregated response with all content.
     all_parts: list[types.Part] = []
@@ -505,6 +540,7 @@ async def _generate_content_streaming(
             candidates_token_count=output_tokens,
             total_token_count=input_tokens + output_tokens,
         ),
+        finish_reason=to_google_genai_finish_reason(stop_reason),
         partial=False,
     )
 
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index fb44d5c8e7..baf604ce15 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -1350,3 +1350,274 @@ async def test_non_streaming_does_not_pass_stream_param():
   mock_client.messages.create.assert_called_once()
   _, kwargs = mock_client.messages.create.call_args
   assert "stream" not in kwargs
+
+
+# --- Tests for finish_reason population ---
+
+
+def test_to_google_genai_finish_reason_mappings():
+  """to_google_genai_finish_reason maps all known stop_reason values."""
+  from google.adk.models.anthropic_llm import to_google_genai_finish_reason
+
+  assert to_google_genai_finish_reason("end_turn") == "STOP"
+  assert to_google_genai_finish_reason("stop_sequence") == "STOP"
+  assert to_google_genai_finish_reason("tool_use") == "STOP"
+  assert to_google_genai_finish_reason("pause_turn") == "STOP"
+  assert to_google_genai_finish_reason("max_tokens") == "MAX_TOKENS"
+  assert to_google_genai_finish_reason("refusal") == "SAFETY"
+  assert to_google_genai_finish_reason(None) == "FINISH_REASON_UNSPECIFIED"
+  assert to_google_genai_finish_reason("unknown_value") == "FINISH_REASON_UNSPECIFIED"
+
+
+def test_message_to_generate_content_response_sets_finish_reason():
+  """message_to_generate_content_response populates finish_reason."""
+  from google.adk.models.anthropic_llm import message_to_generate_content_response
+
+  message = anthropic_types.Message(
+      id="msg_test",
+      content=[anthropic_types.TextBlock(text="Hi!", type="text", citations=None)],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="end_turn",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=2,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+  response = message_to_generate_content_response(message)
+  assert response.finish_reason == "STOP"
+
+
+def test_message_to_generate_content_response_finish_reason_max_tokens():
+  """message_to_generate_content_response maps max_tokens correctly."""
+  from google.adk.models.anthropic_llm import message_to_generate_content_response
+
+  message = anthropic_types.Message(
+      id="msg_test",
+      content=[anthropic_types.TextBlock(text="...", type="text", citations=None)],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="max_tokens",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=8192,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+  response = message_to_generate_content_response(message)
+  assert response.finish_reason == "MAX_TOKENS"
+
+
+@pytest.mark.asyncio
+async def test_streaming_final_response_has_finish_reason():
+  """Streaming final response should include finish_reason from stop_reason."""
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  events = [
+      MagicMock(
+          type="message_start",
+          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
+      ),
+      MagicMock(
+          type="content_block_start",
+          index=0,
+          content_block=anthropic_types.TextBlock(text="", type="text"),
+      ),
+      MagicMock(
+          type="content_block_delta",
+          index=0,
+          delta=anthropic_types.TextDelta(text="Done.", type="text_delta"),
+      ),
+      MagicMock(type="content_block_stop", index=0),
+      MagicMock(
+          type="message_delta",
+          delta=MagicMock(stop_reason="max_tokens"),
+          usage=MagicMock(output_tokens=3),
+      ),
+      MagicMock(type="message_stop"),
+  ]
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(
+      return_value=_make_mock_stream_events(events)
+  )
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    responses = [
+        r async for r in llm.generate_content_async(llm_request, stream=True)
+    ]
+
+  final = responses[-1]
+  assert final.partial is False
+  assert final.finish_reason == "MAX_TOKENS"
+
+
+# --- Tests for sampling parameter forwarding ---
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_forwards_sampling_params():
+  """Non-streaming generate_content_async forwards temperature/top_p/top_k/stop_sequences."""
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  mock_message = anthropic_types.Message(
+      id="msg_test",
+      content=[
+          anthropic_types.TextBlock(text="Hello!", type="text", citations=None)
+      ],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="end_turn",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=2,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(return_value=mock_message)
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(
+          temperature=0.7,
+          top_p=0.9,
+          top_k=40,
+          stop_sequences=["END", "STOP"],
+          system_instruction="Be helpful",
+      ),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    _ = [r async for r in llm.generate_content_async(llm_request, stream=False)]
+
+  mock_client.messages.create.assert_called_once()
+  _, kwargs = mock_client.messages.create.call_args
+  assert kwargs["temperature"] == 0.7
+  assert kwargs["top_p"] == 0.9
+  assert kwargs["top_k"] == 40
+  assert kwargs["stop_sequences"] == ["END", "STOP"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_forwards_sampling_params():
+  """Streaming generate_content_async forwards temperature/top_p/top_k/stop_sequences."""
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  events = [
+      MagicMock(
+          type="message_start",
+          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
+      ),
+      MagicMock(type="content_block_start", index=0,
+                content_block=anthropic_types.TextBlock(text="", type="text")),
+      MagicMock(
+          type="content_block_delta", index=0,
+          delta=anthropic_types.TextDelta(text="Hi", type="text_delta"),
+      ),
+      MagicMock(type="content_block_stop", index=0),
+      MagicMock(
+          type="message_delta",
+          delta=MagicMock(stop_reason="end_turn"),
+          usage=MagicMock(output_tokens=1),
+      ),
+      MagicMock(type="message_stop"),
+  ]
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(
+      return_value=_make_mock_stream_events(events)
+  )
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(
+          temperature=0.3,
+          top_p=0.95,
+          top_k=10,
+          stop_sequences=["<|END|>"],
+          system_instruction="Test",
+      ),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    _ = [r async for r in llm.generate_content_async(llm_request, stream=True)]
+
+  mock_client.messages.create.assert_called_once()
+  _, kwargs = mock_client.messages.create.call_args
+  assert kwargs["temperature"] == 0.3
+  assert kwargs["top_p"] == 0.95
+  assert kwargs["top_k"] == 10
+  assert kwargs["stop_sequences"] == ["<|END|>"]
+  assert kwargs["stream"] is True
+
+
+@pytest.mark.asyncio
+async def test_sampling_params_use_not_given_when_absent():
+  """Sampling params default to NOT_GIVEN when absent from config."""
+  from anthropic import NOT_GIVEN
+
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  mock_message = anthropic_types.Message(
+      id="msg_test",
+      content=[
+          anthropic_types.TextBlock(text="Hi!", type="text", citations=None)
+      ],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="end_turn",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=2,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(return_value=mock_message)
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    _ = [r async for r in llm.generate_content_async(llm_request, stream=False)]
+
+  _, kwargs = mock_client.messages.create.call_args
+  assert kwargs["temperature"] is NOT_GIVEN
+  assert kwargs["top_p"] is NOT_GIVEN
+  assert kwargs["top_k"] is NOT_GIVEN
+  assert kwargs["stop_sequences"] is NOT_GIVEN
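
---
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end
illustration of the behavior this change enables, mirroring the tests above.
The import paths for LlmRequest, Content, and Part are assumptions inferred
from the test file, and the snippet assumes a configured ANTHROPIC_API_KEY;
treat it as a sketch, not shipped code. Indented (no +/- prefixes) so it is
not mistaken for diff content.

    import asyncio

    from google.adk.models.anthropic_llm import AnthropicLlm
    from google.adk.models.llm_request import LlmRequest  # assumed path
    from google.genai import types
    from google.genai.types import Content, Part  # assumed path


    async def main() -> None:
      llm = AnthropicLlm(model="claude-sonnet-4-20250514")
      request = LlmRequest(
          model="claude-sonnet-4-20250514",
          contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
          config=types.GenerateContentConfig(
              temperature=0.2,        # now forwarded to messages.create
              top_p=0.9,
              top_k=40,
              stop_sequences=["END"],
              system_instruction="Be terse.",
          ),
      )
      async for response in llm.generate_content_async(request, stream=False):
        # finish_reason is now populated, e.g. "STOP" or "MAX_TOKENS".
        print(response.finish_reason, response.content)


    asyncio.run(main())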