Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import nox

@nox.session(python=["3.10", "3.11", "3.12", "3.13"])
def lint(session):
session.install("-e", ".")
session.install("pylint")
session.run("pylint", "src/google", "--rcfile=pylintrc")

@nox.session(python=["3.10", "3.11", "3.12", "3.13"])
def unit(session):
session.install("-e", ".[test]")
session.run("pytest", "tests/unittests")
35 changes: 34 additions & 1 deletion src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,8 @@ def __init__(
extra_plugins: Optional[list[str]] = None,
logo_text: Optional[str] = None,
logo_image_url: Optional[str] = None,
max_llm_calls: int = 500,
avatar_config: Optional[str] = None,
url_prefix: Optional[str] = None,
auto_create_session: bool = False,
trigger_sources: Optional[list[str]] = None,
Expand All @@ -675,10 +677,31 @@ def __init__(
self.runners_to_clean: set[str] = set()
self.current_app_name_ref: SharedValue[str] = SharedValue(value="")
self.runner_dict = {}
self.max_llm_calls = max_llm_calls
self.avatar_config = avatar_config
self.url_prefix = url_prefix
self.auto_create_session = auto_create_session
self.trigger_sources = trigger_sources

def _get_avatar_config(self) -> Optional[types.AvatarConfig]:
"""Parses avatar_config string or file into AvatarConfig object."""
if not self.avatar_config:
return None

try:
# Check if it's a file path
if os.path.isfile(self.avatar_config):
with open(self.avatar_config, "r", encoding="utf-8") as f:
config_dict = json.load(f)
else:
# Assume it's a JSON string
config_dict = json.loads(self.avatar_config)

return types.AvatarConfig.model_validate(config_dict)
except Exception as e:
logger.error("Failed to parse avatar_config: %s", e)
return None

async def get_runner_async(self, app_name: str) -> Runner:
"""Returns the cached runner for the given app."""
# Handle cleanup
Expand Down Expand Up @@ -1898,6 +1921,10 @@ async def run_agent(req: RunAgentRequest) -> list[Event]:
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(
max_llm_calls=self.max_llm_calls,
avatar_config=self._get_avatar_config(),
),
invocation_id=req.invocation_id,
)
) as agen:
Expand Down Expand Up @@ -1940,7 +1967,11 @@ async def event_generator():
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
run_config=RunConfig(
streaming_mode=stream_mode,
max_llm_calls=self.max_llm_calls,
avatar_config=self._get_avatar_config(),
),
invocation_id=req.invocation_id,
)
) as agen:
Expand Down Expand Up @@ -2119,6 +2150,8 @@ async def forward_events():
else None
),
save_live_blob=save_live_blob,
max_llm_calls=self.max_llm_calls,
avatar_config=self._get_avatar_config(),
)
async with Aclosing(
runner.run_live(
Expand Down
35 changes: 34 additions & 1 deletion src/google/adk/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ async def run_input_file(
credential_service: BaseCredentialService,
input_path: str,
memory_service: Optional[BaseMemoryService] = None,
max_llm_calls: int = 500,
avatar_config: Optional[types.AvatarConfig] = None,
) -> Session:
app = (
agent_or_app
Expand All @@ -81,7 +83,12 @@ async def run_input_file(
content = types.Content(role='user', parts=[types.Part(text=query)])
async with Aclosing(
runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
user_id=session.user_id,
session_id=session.id,
new_message=content,
run_config=RunConfig(
max_llm_calls=max_llm_calls, avatar_config=avatar_config
),
)
) as agen:
async for event in agen:
Expand All @@ -98,6 +105,8 @@ async def run_interactively(
session_service: BaseSessionService,
credential_service: BaseCredentialService,
memory_service: Optional[BaseMemoryService] = None,
max_llm_calls: int = 500,
avatar_config: Optional[types.AvatarConfig] = None,
) -> None:
app = (
root_agent_or_app
Expand All @@ -124,6 +133,9 @@ async def run_interactively(
new_message=types.Content(
role='user', parts=[types.Part(text=query)]
),
run_config=RunConfig(
max_llm_calls=max_llm_calls, avatar_config=avatar_config
),
)
) as agen:
async for event in agen:
Expand All @@ -145,6 +157,8 @@ async def run_cli(
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
use_local_storage: bool = True,
max_llm_calls: int = 500,
avatar_config: Optional[str] = None,
) -> None:
"""Runs an interactive CLI for a certain agent.

Expand All @@ -170,6 +184,19 @@ async def run_cli(
user_id = 'test_user'

agents_dir = str(agent_parent_path)

avatar_config_obj = None
if avatar_config:
try:
if Path(avatar_config).is_file():
with open(avatar_config, "r", encoding="utf-8") as f:
config_dict = json.load(f)
else:
config_dict = json.loads(avatar_config)
avatar_config_obj = types.AvatarConfig.model_validate(config_dict)
except Exception as e:
click.secho(f"Warning: Failed to parse avatar_config: {e}", fg="yellow")

agent_loader = AgentLoader(agents_dir=agents_dir)
agent_or_app = agent_loader.load_agent(agent_folder_name)
session_app_name = (
Expand Down Expand Up @@ -224,6 +251,8 @@ def _print_event(event) -> None:
memory_service=memory_service,
credential_service=credential_service,
input_path=input_file,
max_llm_calls=max_llm_calls,
avatar_config=avatar_config_obj,
)
elif saved_session_file:
# Load the saved session from file
Expand All @@ -250,6 +279,8 @@ def _print_event(event) -> None:
session_service,
credential_service,
memory_service=memory_service,
max_llm_calls=max_llm_calls,
avatar_config=avatar_config_obj,
)
else:
session = await session_service.create_session(
Expand All @@ -263,6 +294,8 @@ def _print_event(event) -> None:
session_service,
credential_service,
memory_service=memory_service,
max_llm_calls=max_llm_calls,
avatar_config=avatar_config_obj,
)

if save_session:
Expand Down
41 changes: 41 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def wrapper(*args, **kwargs):
@main.command("run", cls=HelpfulCommand)
@feature_options()
@adk_services_options(default_use_local_storage=True)
@fast_api_common_options()
@click.option(
"--save_session",
type=bool,
Expand Down Expand Up @@ -659,6 +660,20 @@ def cli_run(
session_id: Optional[str],
replay: Optional[str],
resume: Optional[str],
host: str = "127.0.0.1",
port: int = 8000,
allow_origins: Optional[list[str]] = None,
log_level: str = "INFO",
trace_to_cloud: bool = False,
otel_to_cloud: bool = False,
reload: bool = True,
a2a: bool = False,
reload_agents: bool = False,
eval_storage_uri: Optional[str] = None,
max_llm_calls: int = 500,
avatar_config: Optional[str] = None,
extra_plugins: Optional[list[str]] = None,
url_prefix: Optional[str] = None,
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
Expand Down Expand Up @@ -1537,6 +1552,23 @@ def decorator(func):
),
default=None,
)
@click.option(
"--max_llm_calls",
type=int,
default=500,
show_default=True,
help="Optional. Maximum number of LLM calls allowed for a given run.",
)
@click.option(
"--avatar_config",
type=str,
help=(
"Optional. JSON string or path to JSON file containing"
" avatar configuration for live sessions (e.g.,"
" '{\"avatarName\": \"avatar_id\"}')."
),
default=None,
)
@click.option(
"--extra_plugins",
help=(
Expand Down Expand Up @@ -1609,6 +1641,8 @@ def wrapper(ctx, *args, **kwargs):
def cli_web(
agents_dir: str,
eval_storage_uri: Optional[str] = None,
max_llm_calls: int = 500,
avatar_config: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
Expand Down Expand Up @@ -1672,7 +1706,10 @@ async def _lifespan(app: FastAPI):
memory_service_uri=memory_service_uri,
use_local_storage=use_local_storage,
eval_storage_uri=eval_storage_uri,
max_llm_calls=max_llm_calls,
avatar_config=avatar_config,
allow_origins=allow_origins,

web=True,
trace_to_cloud=trace_to_cloud,
otel_to_cloud=otel_to_cloud,
Expand Down Expand Up @@ -1723,6 +1760,8 @@ async def _lifespan(app: FastAPI):
def cli_api_server(
agents_dir: str,
eval_storage_uri: Optional[str] = None,
max_llm_calls: int = 500,
avatar_config: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
Expand Down Expand Up @@ -1764,6 +1803,8 @@ def cli_api_server(
memory_service_uri=memory_service_uri,
use_local_storage=use_local_storage,
eval_storage_uri=eval_storage_uri,
max_llm_calls=max_llm_calls,
avatar_config=avatar_config,
allow_origins=allow_origins,
web=False,
trace_to_cloud=trace_to_cloud,
Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def get_fast_api_app(
memory_service_uri: Optional[str] = None,
use_local_storage: bool = True,
eval_storage_uri: Optional[str] = None,
max_llm_calls: int = 500,
avatar_config: Optional[str] = None,
allow_origins: Optional[list[str]] = None,
web: bool,
a2a: bool = False,
Expand Down