Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions src/google/adk/utils/instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import re
from typing import Any

from ..agents.readonly_context import ReadonlyContext
from ..sessions.state import State
Expand Down Expand Up @@ -46,7 +47,11 @@ async def build_instruction(
) -> str:
return await inject_session_state(
'You can inject a state variable like {var_name} or an artifact '
'{artifact.file_name} into the instruction template.',
'{artifact.file_name} into the instruction template.'
'You can also inject a nested variable like {var_name.nested_var}.'
'If a variable or nested attribute may be missing, append `?` to the '
'path or attribute name for optional handling, e.g. '
'{var_name.optional_nested_var?}.',
readonly_context,
)

Expand Down Expand Up @@ -78,14 +83,52 @@ async def _async_sub(pattern, repl_async_fn, string) -> str:
result.append(string[last_end:])
return ''.join(result)

def _get_nested_value(obj: Any, path: str) -> Any:
"""Retrieve nested value from an object based on dot-separated path."""
parts = path.split('.')
current = obj

for part in parts:
if current is None:
return None

optional = part.endswith('?')
key = part[:-1] if optional else part
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve consistency across the file for stripping the optional marker, you could use removesuffix('?'). This method is already used on line 150 and is generally safer and more readable than slicing.

Suggested change
key = part[:-1] if optional else part
key = part.removesuffix('?')


# Try dictionary access first
if hasattr(current, '__getitem__'):
try:
current = current[key]
continue
except (KeyError, TypeError):
# If dict access fails, fall through to try getattr
# UNLESS it's a pure dict which definitely doesn't have attributes
if isinstance(current, dict):
if optional:
return None
raise KeyError(f"Key '{key}' not found in path '{path}'")
pass

# Try attribute access
try:
current = getattr(current, key)
except AttributeError:
# Both dict access and attribute access failed.
if optional:
return None
raise KeyError(f"Key '{key}' not found in path '{path}'")

return current

async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
optional = False
if var_name.endswith('?'):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
full_path = match.group().lstrip('{').rstrip('}').strip()

if full_path.startswith('artifact.'):
var_name = full_path.removeprefix('artifact.')
optional = var_name.endswith('?')
if optional:
var_name = var_name[:-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with how optional markers are handled for state variables on line 150, consider using removesuffix('?') here as well. It's slightly more descriptive than slicing.

Suggested change
var_name = var_name[:-1]
var_name = var_name.removesuffix('?')


if invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = await invocation_context.artifact_service.load_artifact(
Expand All @@ -104,22 +147,17 @@ async def _replace_match(match) -> str:
raise KeyError(f'Artifact {var_name} not found.')
return str(artifact)
else:
if not _is_valid_state_name(var_name):
if not _is_valid_state_name(full_path.split('.')[0].removesuffix('?')):
return match.group()
if var_name in invocation_context.session.state:
value = invocation_context.session.state[var_name]

try:
value = _get_nested_value(invocation_context.session.state, full_path)

if value is None:
return ''
return str(value)
else:
if optional:
logger.debug(
'Context variable %s not found, replacing with empty string',
var_name,
)
return ''
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
except KeyError as e:
raise KeyError(f'Context variable not found: `{full_path}`.') from e

return await _async_sub(r'{+[^{}]*}+', _replace_match, template)

Expand Down
215 changes: 215 additions & 0 deletions tests/unittests/utils/test_instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,218 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty():
instruction_template, invocation_context
)
assert populated_instruction == "Optional value: "


# Tests for nested state access feature
@pytest.mark.asyncio
async def test_inject_session_state_with_nested_dict_access():
instruction_template = (
"User name is {user.name} and role is {user.profile.role}"
)
invocation_context = await _create_test_readonly_context(
state={
"user": {
"name": "Alice",
"profile": {"role": "Engineer", "level": "Senior"},
}
}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "User name is Alice and role is Engineer"


@pytest.mark.asyncio
async def test_inject_session_state_with_deep_nested_access():
instruction_template = "Deep value: {level1.level2.level3.value}"
invocation_context = await _create_test_readonly_context(
state={
"level1": {
"level2": {"level3": {"value": "deep_data", "other": "ignored"}}
}
}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Deep value: deep_data"


@pytest.mark.asyncio
async def test_inject_session_state_with_optional_nested_access_existing():
instruction_template = "Name: {user?.name} Role: {user?.profile?.role}"
invocation_context = await _create_test_readonly_context(
state={
"user": {
"name": "Bob",
"profile": {"role": "Developer"},
}
}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Name: Bob Role: Developer"


@pytest.mark.asyncio
async def test_inject_session_state_with_optional_nested_access_missing():
instruction_template = "Name: {user?.name} Missing: {user?.missing?.field?}"
invocation_context = await _create_test_readonly_context(
state={"user": {"name": "Charlie"}}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Name: Charlie Missing: "


@pytest.mark.asyncio
async def test_inject_session_state_with_optional_nested_missing_root():
instruction_template = "Optional nested: {missing_root?.nested?.value?}"
invocation_context = await _create_test_readonly_context(state={})

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Optional nested: "


@pytest.mark.asyncio
async def test_inject_session_state_with_nested_none_value():
instruction_template = "Value: {user.profile.role}"
invocation_context = await _create_test_readonly_context(
state={"user": {"profile": None}}
)

# When a value in the path is None, it returns empty string
populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Value: "


@pytest.mark.asyncio
async def test_inject_session_state_with_optional_nested_none_value():
instruction_template = "Value: {user.profile?.role?}"
invocation_context = await _create_test_readonly_context(
state={"user": {"profile": None}}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Value: "


@pytest.mark.asyncio
async def test_inject_session_state_with_missing_nested_key_raises_error():
instruction_template = "Value: {user.profile.missing_key}"
invocation_context = await _create_test_readonly_context(
state={"user": {"profile": {"role": "Engineer"}}}
)

with pytest.raises(
KeyError, match="Context variable not found: `user.profile.missing_key`"
):
await instructions_utils.inject_session_state(
instruction_template, invocation_context
)


@pytest.mark.asyncio
async def test_inject_session_state_with_required_parent_missing_raises_error():
"""Test that {user.profile?} raises error when 'user' (required) is missing.

This verifies that optional chaining is per-segment, not for the whole path.
Even though 'profile?' is optional, 'user' is required and should raise error.
"""
instruction_template = "Value: {user.profile?}"
invocation_context = await _create_test_readonly_context(state={})

with pytest.raises(
KeyError, match="Context variable not found: `user.profile\\?`"
):
await instructions_utils.inject_session_state(
instruction_template, invocation_context
)


@pytest.mark.asyncio
async def test_inject_session_state_with_nested_and_prefixed_state():
instruction_template = "User: {app:user.name} Temp: {temp:session.id}"
invocation_context = await _create_test_readonly_context(
state={
"app:user": {"name": "Dana"},
"temp:session": {"id": "session_123"},
}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "User: Dana Temp: session_123"


@pytest.mark.asyncio
async def test_inject_session_state_with_mixed_nested_and_flat_state():
instruction_template = (
"Flat: {simple_key}, Nested: {user.name}, Deep: {config.app.version}"
)
invocation_context = await _create_test_readonly_context(
state={
"simple_key": "simple_value",
"user": {"name": "Eve"},
"config": {"app": {"version": "1.0.0"}},
}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Flat: simple_value, Nested: Eve, Deep: 1.0.0"


@pytest.mark.asyncio
async def test_inject_session_state_with_numeric_nested_values():
instruction_template = "Age: {user.age}, Score: {user.metrics.score}"
invocation_context = await _create_test_readonly_context(
state={"user": {"age": 25, "metrics": {"score": 95.5}}}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Age: 25, Score: 95.5"


@pytest.mark.asyncio
async def test_inject_session_state_with_nested_object_attribute_access():
"""Test accessing attributes on objects (not just dicts)"""

class UserProfile:

def __init__(self):
self.role = "Engineer"
self.department = "Engineering"

class User:

def __init__(self):
self.name = "Frank"
self.profile = UserProfile()

instruction_template = "Name: {user.name}, Role: {user.profile.role}"
invocation_context = await _create_test_readonly_context(
state={"user": User()}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Name: Frank, Role: Engineer"