diff --git a/sqlspec/observability/_dispatcher.py b/sqlspec/observability/_dispatcher.py index 2583167d..2887d5ce 100644 --- a/sqlspec/observability/_dispatcher.py +++ b/sqlspec/observability/_dispatcher.py @@ -37,9 +37,31 @@ class LifecycleDispatcher: """Dispatches lifecycle hooks with guard flags and diagnostics counters.""" - __slots__ = ("_hooks", "_counters", *GUARD_ATTRS) + __slots__ = ( + "_counters", + "_hooks", + "has_connection_create", + "has_connection_destroy", + "has_error", + "has_pool_create", + "has_pool_destroy", + "has_query_complete", + "has_query_start", + "has_session_end", + "has_session_start", + ) def __init__(self, hooks: "dict[str, Iterable[Any]] | None" = None) -> None: + self.has_pool_create = False + self.has_pool_destroy = False + self.has_connection_create = False + self.has_connection_destroy = False + self.has_session_start = False + self.has_session_end = False + self.has_query_start = False + self.has_query_complete = False + self.has_error = False + normalized: dict[LifecycleEvent, tuple[Any, ...]] = {} for event_name, guard_attr in zip(EVENT_ATTRS, GUARD_ATTRS, strict=False): callables = hooks.get(event_name) if hooks else None diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py index 5511ca09..fadde02c 100644 --- a/tests/unit/test_observability.py +++ b/tests/unit/test_observability.py @@ -216,6 +216,25 @@ def observer(_event: Any) -> None: assert observer_called == [] # observers run via runtime, dispatcher unaffected +def test_lifecycle_dispatcher_guard_attributes_always_accessible() -> None: + """All guard attributes should be accessible even with no hooks (mypyc compatibility).""" + + dispatcher = LifecycleDispatcher(None) + assert dispatcher.has_pool_create is False + assert dispatcher.has_pool_destroy is False + assert dispatcher.has_connection_create is False + assert dispatcher.has_connection_destroy is False + assert dispatcher.has_session_start is False + assert dispatcher.has_session_end is False + assert dispatcher.has_query_start is False + assert dispatcher.has_query_complete is False + assert dispatcher.has_error is False + + dispatcher_with_hooks = LifecycleDispatcher(cast("dict[str, Iterable[Any]]", {"on_query_start": [lambda ctx: ctx]})) + assert dispatcher_with_hooks.has_query_start is True + assert dispatcher_with_hooks.has_pool_create is False + + def test_lifecycle_dispatcher_counts_events() -> None: """Lifecycle dispatcher should count emitted events for diagnostics."""