Skip to content

Commit

Permalink
Pull out PostgresEventWatcher into its own module and parameterize it…
Browse files Browse the repository at this point in the history
… more (#7666)
  • Loading branch information
gibsondan committed May 2, 2022
1 parent c99184c commit d79bb7b
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
retry_mysql_creation_fn,
)

CHANNEL_NAME = "run_events"


class MySQLEventLogStorage(SqlEventLogStorage, ConfigurableClass):
"""MySQL-backed event log storage.
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .event_log import PostgresEventLogStorage
from .event_watcher import PostgresEventWatcher
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import logging
import threading
from collections import defaultdict
from typing import Callable, List, MutableMapping, Optional
from typing import Optional

import sqlalchemy as db

Expand All @@ -14,11 +11,9 @@
SqlEventLogStorageTable,
)
from dagster.core.storage.event_log.migration import ASSET_KEY_INDEX_COLS
from dagster.core.storage.event_log.polling_event_watcher import CallbackAfterCursor
from dagster.core.storage.sql import create_engine, run_alembic_upgrade, stamp_alembic_rev
from dagster.serdes import ConfigurableClass, ConfigurableClassData, deserialize_as

from ..pynotify import await_pg_notifications
from ..utils import (
create_pg_connection,
pg_alembic_config,
Expand All @@ -28,6 +23,7 @@
retry_pg_connection_fn,
retry_pg_creation_fn,
)
from .event_watcher import PostgresEventWatcher

CHANNEL_NAME = "run_events"

Expand Down Expand Up @@ -233,10 +229,23 @@ def enable_secondary_index(self, name):

def watch(self, run_id, start_cursor, callback):
if self._event_watcher is None:
self._event_watcher = PostgresEventWatcher(self.postgres_url, self._engine)
self._event_watcher = PostgresEventWatcher(
self.postgres_url,
[CHANNEL_NAME],
self._gen_event_log_entry_from_cursor,
)

self._event_watcher.watch_run(run_id, start_cursor, callback)

def _gen_event_log_entry_from_cursor(self, cursor) -> EventLogEntry:
with self._engine.connect() as conn:
cursor_res = conn.execute(
db.select([SqlEventLogStorageTable.c.event]).where(
SqlEventLogStorageTable.c.id == cursor
),
)
return deserialize_as(cursor_res.scalar(), EventLogEntry)

def end_watch(self, run_id, handler):
if self._event_watcher is None:
return
Expand All @@ -252,121 +261,3 @@ def dispose(self):
self._disposed = True
if self._event_watcher:
self._event_watcher.close()


POLLING_CADENCE = 0.25


def watcher_thread(
conn_string: str,
engine: db.engine.Engine,
handlers_dict: MutableMapping[str, List[CallbackAfterCursor]],
dict_lock: threading.Lock,
watcher_thread_exit: threading.Event,
watcher_thread_started: threading.Event,
):
for notif in await_pg_notifications(
conn_string,
channels=[CHANNEL_NAME],
timeout=POLLING_CADENCE,
yield_on_timeout=True,
exit_event=watcher_thread_exit,
started_event=watcher_thread_started,
):
if notif is None:
if watcher_thread_exit.is_set():
break
else:
run_id, index_str = notif.payload.split("_")
with dict_lock:
if run_id not in handlers_dict:
continue

index = int(index_str)
with dict_lock:
handlers = handlers_dict.get(run_id, [])

with engine.connect() as conn:
cursor_res = conn.execute(
db.select([SqlEventLogStorageTable.c.event]).where(
SqlEventLogStorageTable.c.id == index
),
)
dagster_event = deserialize_as(cursor_res.scalar(), EventLogEntry)

for callback_with_cursor in handlers:
if callback_with_cursor.start_cursor < index:
try:
callback_with_cursor.callback(dagster_event)
except Exception:
logging.exception(
"Exception in callback for event watch on run %s.", run_id
)


class PostgresEventWatcher:
def __init__(self, conn_string: str, engine: db.engine.Engine):
self._conn_string: str = check.str_param(conn_string, "conn_string")
self._engine = engine
self._handlers_dict: MutableMapping[str, List[CallbackAfterCursor]] = defaultdict(list)
self._dict_lock: threading.Lock = threading.Lock()
self._watcher_thread_exit: Optional[threading.Event] = None
self._watcher_thread_started: Optional[threading.Event] = None
self._watcher_thread: Optional[threading.Thread] = None

def watch_run(
self,
run_id: str,
start_cursor: int,
callback: Callable[[EventLogEntry], None],
start_timeout=15,
):
check.str_param(run_id, "run_id")
check.int_param(start_cursor, "start_cursor")
check.callable_param(callback, "callback")
if not self._watcher_thread:
self._watcher_thread_exit = threading.Event()
self._watcher_thread_started = threading.Event()

self._watcher_thread = threading.Thread(
target=watcher_thread,
args=(
self._conn_string,
self._engine,
self._handlers_dict,
self._dict_lock,
self._watcher_thread_exit,
self._watcher_thread_started,
),
name="postgres-event-watch",
)
self._watcher_thread.daemon = True
self._watcher_thread.start()

# Wait until the watcher thread is actually listening before returning
self._watcher_thread_started.wait(start_timeout)
if not self._watcher_thread_started.is_set():
raise Exception("Watcher thread never started")

with self._dict_lock:
self._handlers_dict[run_id].append(CallbackAfterCursor(start_cursor + 1, callback))

def unwatch_run(self, run_id: str, handler: Callable[[EventLogEntry], None]):
check.str_param(run_id, "run_id")
check.callable_param(handler, "handler")
with self._dict_lock:
if run_id in self._handlers_dict:
self._handlers_dict[run_id] = [
callback_with_cursor
for callback_with_cursor in self._handlers_dict[run_id]
if callback_with_cursor.callback != handler
]
if not self._handlers_dict[run_id]:
del self._handlers_dict[run_id]

def close(self):
if self._watcher_thread:
self._watcher_thread_exit.set()
self._watcher_thread.join()
self._watcher_thread_exit = None
self._watcher_thread = None
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import logging
import threading
from collections import defaultdict
from typing import Callable, List, MutableMapping, Optional

from dagster import check
from dagster.core.events.log import EventLogEntry
from dagster.core.storage.event_log.polling_event_watcher import CallbackAfterCursor

from ..pynotify import await_pg_notifications

POLLING_CADENCE = 0.25


def watcher_thread(
conn_string: str,
handlers_dict: MutableMapping[str, List[CallbackAfterCursor]],
dict_lock: threading.Lock,
watcher_thread_exit: threading.Event,
watcher_thread_started: threading.Event,
channels: List[str],
gen_event_log_entry_from_cursor: Callable[[int], EventLogEntry],
):
for notif in await_pg_notifications(
conn_string,
channels=channels,
timeout=POLLING_CADENCE,
yield_on_timeout=True,
exit_event=watcher_thread_exit,
started_event=watcher_thread_started,
):
if notif is None:
if watcher_thread_exit.is_set():
break
else:
run_id, index_str = notif.payload.split("_")
with dict_lock:
if run_id not in handlers_dict:
continue

index = int(index_str)
with dict_lock:
handlers = handlers_dict.get(run_id, [])

dagster_event = gen_event_log_entry_from_cursor(index)

for callback_with_cursor in handlers:
if callback_with_cursor.start_cursor < index:
try:
callback_with_cursor.callback(dagster_event)
except Exception:
logging.exception(
"Exception in callback for event watch on run %s.", run_id
)


class PostgresEventWatcher:
def __init__(
self,
conn_string: str,
channels: List[str],
gen_event_log_entry_from_cursor: Callable[[int], EventLogEntry],
):
self._conn_string: str = check.str_param(conn_string, "conn_string")
self._handlers_dict: MutableMapping[str, List[CallbackAfterCursor]] = defaultdict(list)
self._dict_lock: threading.Lock = threading.Lock()
self._watcher_thread_exit: Optional[threading.Event] = None
self._watcher_thread_started: Optional[threading.Event] = None
self._watcher_thread: Optional[threading.Thread] = None
self._channels: List[str] = check.list_param(channels, "channels")
self._gen_event_log_entry_from_cursor: Callable[
[int], EventLogEntry
] = check.callable_param(gen_event_log_entry_from_cursor, "gen_event_log_entry_from_cursor")

def watch_run(
self,
run_id: str,
start_cursor: int,
callback: Callable[[EventLogEntry], None],
start_timeout=15,
):
check.str_param(run_id, "run_id")
check.int_param(start_cursor, "start_cursor")
check.callable_param(callback, "callback")
if not self._watcher_thread:
self._watcher_thread_exit = threading.Event()
self._watcher_thread_started = threading.Event()

self._watcher_thread = threading.Thread(
target=watcher_thread,
args=(
self._conn_string,
self._handlers_dict,
self._dict_lock,
self._watcher_thread_exit,
self._watcher_thread_started,
self._channels,
self._gen_event_log_entry_from_cursor,
),
name="postgres-event-watch",
)
self._watcher_thread.daemon = True
self._watcher_thread.start()

# Wait until the watcher thread is actually listening before returning
self._watcher_thread_started.wait(start_timeout)
if not self._watcher_thread_started.is_set():
raise Exception("Watcher thread never started")

with self._dict_lock:
self._handlers_dict[run_id].append(CallbackAfterCursor(start_cursor + 1, callback))

def unwatch_run(self, run_id: str, handler: Callable[[EventLogEntry], None]):
check.str_param(run_id, "run_id")
check.callable_param(handler, "handler")
with self._dict_lock:
if run_id in self._handlers_dict:
self._handlers_dict[run_id] = [
callback_with_cursor
for callback_with_cursor in self._handlers_dict[run_id]
if callback_with_cursor.callback != handler
]
if not self._handlers_dict[run_id]:
del self._handlers_dict[run_id]

def close(self):
if self._watcher_thread:
self._watcher_thread_exit.set()
self._watcher_thread.join()
self._watcher_thread_exit = None
self._watcher_thread = None

0 comments on commit d79bb7b

Please sign in to comment.