Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/google/adk/plugins/context_filter_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
num_invocations_to_keep: Optional[int] = None,
custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
name: str = "context_filter_plugin",
remove_amount: int = 1

):
"""Initializes the context management plugin.

Expand All @@ -45,10 +47,12 @@ def __init__(
by a model response.
custom_filter: A function to filter the context.
name: The name of the plugin instance.
remove_amount: The amount to remove the context.
"""
super().__init__(name)
self._num_invocations_to_keep = num_invocations_to_keep
self._custom_filter = custom_filter
self._remove_amount = remove_amount

async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
Expand All @@ -60,9 +64,10 @@ async def before_model_callback(
if (
self._num_invocations_to_keep is not None
and self._num_invocations_to_keep > 0
and self._remove_amount > 0
):
num_model_turns = sum(1 for c in contents if c.role == "model")
if num_model_turns >= self._num_invocations_to_keep:
if num_model_turns >= self._num_invocations_to_keep + self._remove_amount - 1:
model_turns_to_find = self._num_invocations_to_keep
split_index = 0
for i in range(len(contents) - 1, -1, -1):
Expand Down
107 changes: 107 additions & 0 deletions tests/unittests/plugins/test_context_filtering_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,110 @@ def faulty_filter(contents):
)

assert llm_request.contents == original_contents


@pytest.mark.asyncio
async def test_filter_with_remove_amount():
"""Tests that remove_amount correctly removes additional invocations."""
plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
contents = [
_create_content("user", "user_prompt_1"),
_create_content("model", "model_response_1"),
_create_content("user", "user_prompt_2"),
_create_content("model", "model_response_2"),
_create_content("user", "user_prompt_3"),
_create_content("model", "model_response_3"),
]
llm_request = LlmRequest(contents=contents)

await plugin.before_model_callback(
callback_context=Mock(spec=CallbackContext), llm_request=llm_request
)

# With num_invocations_to_keep=2 and remove_amount=1, should keep last 2 invocations
assert len(llm_request.contents) == 4
assert llm_request.contents[0].parts[0].text == "user_prompt_2"
assert llm_request.contents[1].parts[0].text == "model_response_2"
assert llm_request.contents[2].parts[0].text == "user_prompt_3"
assert llm_request.contents[3].parts[0].text == "model_response_3"


@pytest.mark.asyncio
async def test_filter_with_higher_remove_amount():
"""Tests remove_amount with a higher value to remove more invocations."""
plugin = ContextFilterPlugin(num_invocations_to_keep=3, remove_amount=2)
contents = [
_create_content("user", "user_prompt_1"),
_create_content("model", "model_response_1"),
_create_content("user", "user_prompt_2"),
_create_content("model", "model_response_2"),
_create_content("user", "user_prompt_3"),
_create_content("model", "model_response_3"),
_create_content("user", "user_prompt_4"),
_create_content("model", "model_response_4"),
_create_content("user", "user_prompt_5"),
_create_content("model", "model_response_5"),
]
llm_request = LlmRequest(contents=contents)

await plugin.before_model_callback(
callback_context=Mock(spec=CallbackContext), llm_request=llm_request
)

# With num_invocations_to_keep=3 and remove_amount=2, keeps last 3 invocations
assert len(llm_request.contents) == 6
assert llm_request.contents[0].parts[0].text == "user_prompt_3"
assert llm_request.contents[1].parts[0].text == "model_response_3"
assert llm_request.contents[2].parts[0].text == "user_prompt_4"
assert llm_request.contents[3].parts[0].text == "model_response_4"
assert llm_request.contents[4].parts[0].text == "user_prompt_5"
assert llm_request.contents[5].parts[0].text == "model_response_5"


@pytest.mark.asyncio
async def test_filter_with_zero_remove_amount():
"""Tests that remove_amount=0 disables the filtering logic."""
plugin = ContextFilterPlugin(num_invocations_to_keep=1, remove_amount=0)
contents = [
_create_content("user", "user_prompt_1"),
_create_content("model", "model_response_1"),
_create_content("user", "user_prompt_2"),
_create_content("model", "model_response_2"),
]
llm_request = LlmRequest(contents=contents)
original_contents = list(llm_request.contents)

await plugin.before_model_callback(
callback_context=Mock(spec=CallbackContext), llm_request=llm_request
)

# With remove_amount=0, filtering should be disabled
assert llm_request.contents == original_contents


@pytest.mark.asyncio
async def test_filter_remove_amount_with_multiple_user_turns():
"""Tests remove_amount with multiple user turns in invocations."""
plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
contents = [
_create_content("user", "user_prompt_1"),
_create_content("model", "model_response_1"),
_create_content("user", "user_prompt_2a"),
_create_content("user", "user_prompt_2b"),
_create_content("model", "model_response_2"),
_create_content("user", "user_prompt_3"),
_create_content("model", "model_response_3"),
]
llm_request = LlmRequest(contents=contents)

await plugin.before_model_callback(
callback_context=Mock(spec=CallbackContext), llm_request=llm_request
)

# Should keep last 2 invocations including multiple user turns
assert len(llm_request.contents) == 5
assert llm_request.contents[0].parts[0].text == "user_prompt_2a"
assert llm_request.contents[1].parts[0].text == "user_prompt_2b"
assert llm_request.contents[2].parts[0].text == "model_response_2"
assert llm_request.contents[3].parts[0].text == "user_prompt_3"
assert llm_request.contents[4].parts[0].text == "model_response_3"