Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 178 additions & 20 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
import logging
import mimetypes
import os

# Enable gRPC fork support so child processes created via os.fork()
# can safely create new gRPC channels. Must be set before grpc's
# C-core is loaded (which happens through the google.api_core
# imports below). setdefault respects any explicit user override.
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1")

import random
import time
from types import MappingProxyType
Expand Down Expand Up @@ -76,19 +83,29 @@
_SCHEMA_VERSION = "1"
_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"

# Human-in-the-loop (HITL) tool names that receive additional
# dedicated event types alongside the normal TOOL_* events.
_HITL_TOOL_NAMES = frozenset({
"adk_request_credential",
"adk_request_confirmation",
"adk_request_input",
})
_HITL_EVENT_MAP = MappingProxyType({
"adk_request_credential": "HITL_CREDENTIAL_REQUEST",
"adk_request_confirmation": "HITL_CONFIRMATION_REQUEST",
"adk_request_input": "HITL_INPUT_REQUEST",
})

# Track all living plugin instances so the fork handler can reset
# them proactively in the child, before _ensure_started runs.
_LIVE_PLUGINS: weakref.WeakSet = weakref.WeakSet()


def _after_fork_in_child() -> None:
"""Reset every living plugin instance after os.fork()."""
for plugin in list(_LIVE_PLUGINS):
try:
plugin._reset_runtime_state()
except Exception:
pass


if hasattr(os, "register_at_fork"):
os.register_at_fork(after_in_child=_after_fork_in_child)


def _safe_callback(func):
"""Decorator that catches and logs exceptions in plugin callbacks.
Expand Down Expand Up @@ -1407,7 +1424,10 @@ def process_text(t: str) -> tuple[str, bool]:
if content.config and getattr(content.config, "system_instruction", None):
si = content.config.system_instruction
if isinstance(si, str):
json_payload["system_prompt"] = si
truncated_si, trunc = process_text(si)
if trunc:
is_truncated = True
json_payload["system_prompt"] = truncated_si
else:
summary, parts, trunc = await self._parse_content_object(si)
if trunc:
Expand Down Expand Up @@ -1855,6 +1875,7 @@ def __init__(
self._schema = None
self.arrow_schema = None
self._init_pid = os.getpid()
_LIVE_PLUGINS.add(self)

def _cleanup_stale_loop_states(self) -> None:
"""Removes entries for event loops that have been closed."""
Expand Down Expand Up @@ -2142,9 +2163,73 @@ def _ensure_schema_exists(self) -> None:
exc_info=True,
)

@staticmethod
def _schema_fields_match(
    existing: list[bq_schema.SchemaField],
    desired: list[bq_schema.SchemaField],
) -> tuple[
    list[bq_schema.SchemaField],
    list[bq_schema.SchemaField],
]:
    """Compares existing vs desired schema fields recursively.

    Args:
      existing: Schema fields currently present on the table.
      desired: Schema fields this plugin version expects.

    Returns:
      A tuple of (new_top_level_fields, updated_record_fields).
      ``new_top_level_fields`` are fields in *desired* that are
      entirely absent from *existing*.
      ``updated_record_fields`` are RECORD fields that exist in
      both but have new sub-fields in *desired*; each entry is a
      copy of the existing field with the missing sub-fields
      appended.
    """
    existing_by_name = {f.name: f for f in existing}
    new_fields: list[bq_schema.SchemaField] = []
    updated_records: list[bq_schema.SchemaField] = []

    for desired_field in desired:
        existing_field = existing_by_name.get(desired_field.name)
        if existing_field is None:
            new_fields.append(desired_field)
        elif (
            desired_field.field_type == "RECORD"
            and existing_field.field_type == "RECORD"
            and desired_field.fields
        ):
            # Recurse into nested RECORD fields.
            sub_new, sub_updated = (
                BigQueryAgentAnalyticsPlugin._schema_fields_match(
                    list(existing_field.fields),
                    list(desired_field.fields),
                )
            )
            if sub_new or sub_updated:
                # Build a merged sub-field list: swap in updated
                # nested records via an O(1) dict lookup (instead of
                # a nested next() scan, which is O(n*m)), then append
                # the entirely new sub-fields at the end.
                sub_updated_by_name = {f.name: f for f in sub_updated}
                merged_sub = [
                    sub_updated_by_name.get(f.name, f)
                    for f in existing_field.fields
                ]
                merged_sub.extend(sub_new)
                # Rebuild via API representation to preserve all
                # existing field attributes (policy_tags, etc.).
                api_repr = existing_field.to_api_repr()
                api_repr["fields"] = [
                    sf.to_api_repr() for sf in merged_sub
                ]
                updated_records.append(
                    bq_schema.SchemaField.from_api_repr(api_repr)
                )

    return new_fields, updated_records

def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
"""Adds missing columns to an existing table (additive only).

Handles nested RECORD fields by recursing into sub-fields.
The version label is only stamped after a successful update
so that a failed attempt is retried on the next run.

Args:
existing_table: The current BigQuery table object.
"""
Expand All @@ -2154,24 +2239,43 @@ def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
if stored_version == _SCHEMA_VERSION:
return

existing_names = {f.name for f in existing_table.schema}
new_fields = [f for f in self._schema if f.name not in existing_names]
new_fields, updated_records = self._schema_fields_match(
list(existing_table.schema), list(self._schema)
)

if new_fields:
merged = list(existing_table.schema) + new_fields
if new_fields or updated_records:
# Build merged top-level schema.
updated_names = {f.name for f in updated_records}
merged = [
next(u for u in updated_records if u.name == f.name)
if f.name in updated_names
else f
for f in existing_table.schema
]
merged.extend(new_fields)
existing_table.schema = merged

change_desc = []
Comment on lines +2253 to +2258
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the other comment, this list comprehension can be made more efficient by using a dictionary for lookups instead of next() within the comprehension, which avoids a nested loop.

      updated_records_by_name = {f.name: f for f in updated_records}
      merged = [
          updated_records_by_name.get(f.name, f)
          for f in existing_table.schema
      ]

if new_fields:
change_desc.append(f"new columns {[f.name for f in new_fields]}")
if updated_records:
change_desc.append(
f"updated RECORD fields {[f.name for f in updated_records]}"
)
logger.info(
"Auto-upgrading table %s: adding columns %s",
"Auto-upgrading table %s: %s",
self.full_table_id,
[f.name for f in new_fields],
", ".join(change_desc),
)

# Always stamp the version label so we skip on next run.
labels = dict(existing_table.labels or {})
labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
existing_table.labels = labels

try:
# Stamp the version label inside the try block so that
# on failure the label is NOT persisted and the next run
# retries the upgrade.
labels = dict(existing_table.labels or {})
labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
existing_table.labels = labels

update_fields = ["schema", "labels"]
self.client.update_table(existing_table, update_fields)
except Exception as e:
Expand Down Expand Up @@ -2243,6 +2347,22 @@ async def shutdown(self, timeout: float | None = None) -> None:
if loop in self._loop_state_by_loop:
await self._loop_state_by_loop[loop].batch_processor.shutdown(timeout=t)

# 1b. Drain batch processors on other (non-current) loops.
for other_loop, state in self._loop_state_by_loop.items():
if other_loop is loop or other_loop.is_closed():
continue
try:
future = asyncio.run_coroutine_threadsafe(
state.batch_processor.shutdown(timeout=t),
other_loop,
)
future.result(timeout=t)
except Exception:
logger.warning(
"Could not drain batch processor on loop %s",
other_loop,
)

# 2. Close clients for all states
for state in self._loop_state_by_loop.values():
if state.write_client and getattr(
Expand Down Expand Up @@ -2298,6 +2418,38 @@ def _reset_runtime_state(self) -> None:
process. Pure-data fields like ``_schema`` and
``arrow_schema`` are kept because they are safe across fork.
"""
logger.warning(
"Fork detected (parent PID %s, child PID %s). Resetting"
" gRPC state for BigQuery analytics plugin. Note: gRPC"
" bidirectional streaming (used by the BigQuery Storage"
" Write API) is not fork-safe. If writes hang or time"
" out, configure the 'spawn' start method at your program"
" entry-point before creating child processes:"
" multiprocessing.set_start_method('spawn')",
self._init_pid,
os.getpid(),
)
# Best-effort: close inherited gRPC channels so broken
# finalizers don't interfere with newly created channels.
# For grpc.aio channels, close() is a coroutine. We cannot
# await here (called from sync context / fork handler), so
# we skip async channels and only close sync ones.
for loop_state in self._loop_state_by_loop.values():
wc = getattr(loop_state, "write_client", None)
transport = getattr(wc, "transport", None)
if transport is not None:
try:
channel = getattr(transport, "_grpc_channel", None)
if channel is not None and hasattr(channel, "close"):
result = channel.close()
# If close() returned a coroutine (grpc.aio channel),
# discard it to avoid unawaited-coroutine warnings.
if asyncio.iscoroutine(result):
result.close()
except Exception:
pass

# Clear all runtime state.
self._setup_lock = None
self.client = None
self._loop_state_by_loop = {}
Expand Down Expand Up @@ -2442,7 +2594,11 @@ def _enrich_attributes(
# Include session state if non-empty (contains user-set metadata
# like gchat thread-id, customer_id, etc.)
if session.state:
session_meta["state"] = dict(session.state)
truncated_state, _ = _recursive_smart_truncate(
dict(session.state),
self.config.max_content_length,
)
session_meta["state"] = truncated_state
attrs["session_metadata"] = session_meta
except Exception:
pass
Expand Down Expand Up @@ -2988,6 +3144,7 @@ async def on_model_error_callback(
"LLM_ERROR",
callback_context,
event_data=EventData(
status="ERROR",
error_message=str(error),
latency_ms=duration,
span_id_override=None if has_ambient else span_id,
Expand Down Expand Up @@ -3110,6 +3267,7 @@ async def on_tool_error_callback(
raw_content=content_dict,
is_truncated=is_truncated,
event_data=EventData(
status="ERROR",
error_message=str(error),
latency_ms=duration,
span_id_override=None if has_ambient else span_id,
Expand Down
Loading