diff --git a/contributing/samples/core_callback_config/callbacks.py b/contributing/samples/core_callback_config/callbacks.py
index 1614a9351a..dcc61e9111 100644
--- a/contributing/samples/core_callback_config/callbacks.py
+++ b/contributing/samples/core_callback_config/callbacks.py
@@ -16,6 +16,39 @@ async def before_model_callback(callback_context, llm_request):
   return None
 
 
+INJECTIONS_STATE_KEY = "__persisted_prompt_injections"
+_index = 0
+
+def before_model_callback_persist_injections(callback_context, llm_request):
+  """Sample before_model_callback that persists prompt injections in state.
+
+  This function demonstrates how user code can store small text injections
+  in the callback context state so future requests will include them.
+  """
+  global _index
+  if not getattr(llm_request, 'contents', None):
+    llm_request.contents = []
+  injections = callback_context.state.get(INJECTIONS_STATE_KEY, [])
+  for inj in injections:
+    found = False
+    for c in llm_request.contents:
+      for part in getattr(c, 'parts', []):
+        if getattr(part, 'text', None) == inj:
+          found = True
+          break
+      if found:
+        break
+    if not found:
+      llm_request.contents.insert(0, types.Content(role='user', parts=[types.Part(text=inj)]))
+
+  new_inj = f"Ignore this user input: Index={_index}"
+  llm_request.contents.append(types.Content(role='user', parts=[types.Part(text=new_inj)]))
+  injections.append(new_inj)
+  callback_context.state[INJECTIONS_STATE_KEY] = injections
+  _index += 1
+  return None
+
+
 async def after_model_callback(callback_context, llm_response):
   print('@after_model_callback')
   return None
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index f5b242b7c5..6449d9bc30 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -799,6 +799,59 @@ async def _handle_before_model_callback(
         invocation_context, event_actions=model_response_event.actions
     )
 
+    state = getattr(callback_context, 'state', None)
+    if state is None:
+      injections = None
+    else:
+      injections = state.get("__persisted_prompt_injections")
+    if not injections:
+      try:
+        session_state = getattr(invocation_context, 'session', None)
+        if session_state and getattr(session_state, 'state', None) is not None:
+          persisted = session_state.state.get("__persisted_prompt_injections")
+          if persisted:
+            injections = persisted
+      except Exception:
+        pass  # NOTE(review): best-effort fallback; core flow should not hard-code a sample's state key
+
+    if injections:
+      if not isinstance(injections, list):
+        injections = [injections]
+      for inj in injections:
+        if not inj:
+          continue
+        already = False
+        inj_text = None
+        if isinstance(inj, str):
+          inj_text = inj
+        else:
+          parts = getattr(inj, 'parts', None)
+          if parts and hasattr(parts[0], 'text'):
+            inj_text = parts[0].text
+          else:
+            inj_text = None
+
+        if inj_text:
+          for c in list(llm_request.contents or []):
+            if not c.parts:
+              continue
+            for part in c.parts:
+              if part and getattr(part, 'text', None) == inj_text:
+                already = True
+                break
+            if already:
+              break
+          if already:
+            continue
+
+        llm_request.contents = llm_request.contents or []
+        if isinstance(inj, str):
+          llm_request.contents.insert(
+              0, types.Content(role="user", parts=[types.Part(text=inj)])
+          )
+        else:
+          llm_request.contents.insert(0, inj)
+
     # First run callbacks from the plugins.
     callback_response = (
         await invocation_context.plugin_manager.run_before_model_callback(