Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion sqlspec/observability/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,31 @@
class LifecycleDispatcher:
"""Dispatches lifecycle hooks with guard flags and diagnostics counters."""

__slots__ = ("_hooks", "_counters", *GUARD_ATTRS)
__slots__ = (
"_counters",
"_hooks",
"has_connection_create",
"has_connection_destroy",
"has_error",
"has_pool_create",
"has_pool_destroy",
"has_query_complete",
"has_query_start",
"has_session_end",
"has_session_start",
)

def __init__(self, hooks: "dict[str, Iterable[Any]] | None" = None) -> None:
self.has_pool_create = False
self.has_pool_destroy = False
self.has_connection_create = False
self.has_connection_destroy = False
self.has_session_start = False
self.has_session_end = False
self.has_query_start = False
self.has_query_complete = False
self.has_error = False

normalized: dict[LifecycleEvent, tuple[Any, ...]] = {}
for event_name, guard_attr in zip(EVENT_ATTRS, GUARD_ATTRS, strict=False):
callables = hooks.get(event_name) if hooks else None
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,25 @@ def observer(_event: Any) -> None:
assert observer_called == [] # observers run via runtime, dispatcher unaffected


def test_lifecycle_dispatcher_guard_attributes_always_accessible() -> None:
"""All guard attributes should be accessible even with no hooks (mypyc compatibility)."""

dispatcher = LifecycleDispatcher(None)
assert dispatcher.has_pool_create is False
assert dispatcher.has_pool_destroy is False
assert dispatcher.has_connection_create is False
assert dispatcher.has_connection_destroy is False
assert dispatcher.has_session_start is False
assert dispatcher.has_session_end is False
assert dispatcher.has_query_start is False
assert dispatcher.has_query_complete is False
assert dispatcher.has_error is False

dispatcher_with_hooks = LifecycleDispatcher(cast("dict[str, Iterable[Any]]", {"on_query_start": [lambda ctx: ctx]}))
assert dispatcher_with_hooks.has_query_start is True
assert dispatcher_with_hooks.has_pool_create is False


def test_lifecycle_dispatcher_counts_events() -> None:
"""Lifecycle dispatcher should count emitted events for diagnostics."""

Expand Down