From 70db88733c5b1d8260334fdd97560c930577c62b Mon Sep 17 00:00:00 2001 From: "hgyun.lee" Date: Fri, 24 Oct 2025 21:49:40 +0900 Subject: [PATCH 1/2] Fix ContextFilterPlugin to make explicit context caching work better via an N-sized sliding window --- .../adk/plugins/context_filter_plugin.py | 7 +- .../plugins/test_context_filtering_plugin.py | 108 ++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py index b778de02ad..5d62e8f57b 100644 --- a/src/google/adk/plugins/context_filter_plugin.py +++ b/src/google/adk/plugins/context_filter_plugin.py @@ -36,6 +36,8 @@ def __init__( num_invocations_to_keep: Optional[int] = None, custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, name: str = "context_filter_plugin", + remove_amount: int = 1 + ): """Initializes the context management plugin. @@ -45,10 +47,12 @@ def __init__( by a model response. custom_filter: A function to filter the context. name: The name of the plugin instance. + remove_amount: Controls how far the context may grow past the retained window before it is trimmed back to num_invocations_to_keep; larger values trim less often (improving context-cache reuse), and 0 disables filtering. 
""" super().__init__(name) self._num_invocations_to_keep = num_invocations_to_keep self._custom_filter = custom_filter + self._remove_amount = remove_amount async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest @@ -60,9 +64,10 @@ async def before_model_callback( if ( self._num_invocations_to_keep is not None and self._num_invocations_to_keep > 0 + and self._remove_amount > 0 ): num_model_turns = sum(1 for c in contents if c.role == "model") - if num_model_turns >= self._num_invocations_to_keep: + if num_model_turns >= self._num_invocations_to_keep + self._remove_amount - 1: model_turns_to_find = self._num_invocations_to_keep split_index = 0 for i in range(len(contents) - 1, -1, -1): diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py index f9c8222ea3..974d5a9c4f 100644 --- a/tests/unittests/plugins/test_context_filtering_plugin.py +++ b/tests/unittests/plugins/test_context_filtering_plugin.py @@ -183,3 +183,111 @@ def faulty_filter(contents): ) assert llm_request.contents == original_contents + + +@pytest.mark.asyncio +async def test_filter_with_remove_amount(): + """Tests that remove_amount correctly removes additional invocations.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + _create_content("user", "user_prompt_3"), + _create_content("model", "model_response_3"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # With num_invocations_to_keep=2 and remove_amount=1, should keep last 2 invocations + assert len(llm_request.contents) == 4 + assert llm_request.contents[0].parts[0].text == 
"user_prompt_2" + assert llm_request.contents[1].parts[0].text == "model_response_2" + assert llm_request.contents[2].parts[0].text == "user_prompt_3" + assert llm_request.contents[3].parts[0].text == "model_response_3" + + +@pytest.mark.asyncio +async def test_filter_with_higher_remove_amount(): + """Tests remove_amount with a higher value to remove more invocations.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=3, remove_amount=2) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + _create_content("user", "user_prompt_3"), + _create_content("model", "model_response_3"), + _create_content("user", "user_prompt_4"), + _create_content("model", "model_response_4"), + _create_content("user", "user_prompt_5"), + _create_content("model", "model_response_5"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # With num_invocations_to_keep=3 and remove_amount=2, keeps last 2 invocations + # (num_invocations_to_keep - remove_amount = 1, but the calculation keeps 2) + assert len(llm_request.contents) == 6 + assert llm_request.contents[0].parts[0].text == "user_prompt_3" + assert llm_request.contents[1].parts[0].text == "model_response_3" + assert llm_request.contents[2].parts[0].text == "user_prompt_4" + assert llm_request.contents[3].parts[0].text == "model_response_4" + assert llm_request.contents[4].parts[0].text == "user_prompt_5" + assert llm_request.contents[5].parts[0].text == "model_response_5" + + +@pytest.mark.asyncio +async def test_filter_with_zero_remove_amount(): + """Tests that remove_amount=0 disables the filtering logic.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=1, remove_amount=0) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", 
"model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + ] + llm_request = LlmRequest(contents=contents) + original_contents = list(llm_request.contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # With remove_amount=0, filtering should be disabled + assert llm_request.contents == original_contents + + +@pytest.mark.asyncio +async def test_filter_remove_amount_with_multiple_user_turns(): + """Tests remove_amount with multiple user turns in invocations.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2a"), + _create_content("user", "user_prompt_2b"), + _create_content("model", "model_response_2"), + _create_content("user", "user_prompt_3"), + _create_content("model", "model_response_3"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # Should keep last 2 invocations including multiple user turns + assert len(llm_request.contents) == 5 + assert llm_request.contents[0].parts[0].text == "user_prompt_2a" + assert llm_request.contents[1].parts[0].text == "user_prompt_2b" + assert llm_request.contents[2].parts[0].text == "model_response_2" + assert llm_request.contents[3].parts[0].text == "user_prompt_3" + assert llm_request.contents[4].parts[0].text == "model_response_3" From 13a84f41df9d6baf497a776ff547dbeb1275d41e Mon Sep 17 00:00:00 2001 From: "hgyun.lee" Date: Fri, 24 Oct 2025 21:55:40 +0900 Subject: [PATCH 2/2] Apply feedback --- tests/unittests/plugins/test_context_filtering_plugin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py 
b/tests/unittests/plugins/test_context_filtering_plugin.py index 974d5a9c4f..f8c0c4f99b 100644 --- a/tests/unittests/plugins/test_context_filtering_plugin.py +++ b/tests/unittests/plugins/test_context_filtering_plugin.py @@ -233,8 +233,7 @@ async def test_filter_with_higher_remove_amount(): callback_context=Mock(spec=CallbackContext), llm_request=llm_request ) - # With num_invocations_to_keep=3 and remove_amount=2, keeps last 2 invocations - # (num_invocations_to_keep - remove_amount = 1, but the calculation keeps 2) + # With num_invocations_to_keep=3 and remove_amount=2, keeps last 3 invocations assert len(llm_request.contents) == 6 assert llm_request.contents[0].parts[0].text == "user_prompt_3" assert llm_request.contents[1].parts[0].text == "model_response_3"