require get_event_records to have filter arg (#8284)
prha committed Jun 13, 2022
1 parent 7bdf0ce commit b9ac2a3
Showing 6 changed files with 40 additions and 67 deletions.
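In practical terms, `DagsterInstance.get_event_records` and the storage-level implementations below no longer accept a missing filter, and `EventRecordsFilter` itself now requires an event type. A minimal caller-side sketch of the new contract, assuming the top-level `dagster` exports for these names and a hypothetical asset key:

from dagster import AssetKey, DagsterEventType, DagsterInstance, EventRecordsFilter

# `instance` is assumed to be an existing DagsterInstance.
instance = DagsterInstance.get()

# Previously, get_event_records() with no filter only emitted a deprecation
# warning; after this commit the filter (and its event_type) is required.
records = instance.get_event_records(
    EventRecordsFilter(
        event_type=DagsterEventType.ASSET_MATERIALIZATION,
        asset_key=AssetKey("my_asset"),  # hypothetical asset key
    ),
    limit=1,
    ascending=False,
)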
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/core/instance/__init__.py
@@ -1369,7 +1369,7 @@ def get_latest_materialization_events(
     @traced
     def get_event_records(
         self,
-        event_records_filter: Optional["EventRecordsFilter"] = None,
+        event_records_filter: "EventRecordsFilter",
         limit: Optional[int] = None,
         ascending: bool = False,
     ) -> Iterable["EventLogRecord"]:
16 changes: 5 additions & 11 deletions python_modules/dagster/dagster/core/storage/event_log/base.py
@@ -1,5 +1,4 @@
 import base64
-import warnings
 from abc import ABC, abstractmethod
 from datetime import datetime
 from enum import Enum
@@ -131,7 +130,7 @@ class EventRecordsFilter(
     NamedTuple(
         "_EventRecordsFilter",
         [
-            ("event_type", Optional[DagsterEventType]),
+            ("event_type", DagsterEventType),
            ("asset_key", Optional[AssetKey]),
            ("asset_partitions", Optional[List[str]]),
            ("after_cursor", Optional[Union[int, RunShardedEventsCursor]]),
@@ -144,7 +143,7 @@ class EventRecordsFilter(
     """Defines a set of filter fields for fetching a set of event log entries or event log records.
 
     Args:
-        event_type (Optional[DagsterEventType]): Filter argument for dagster event type
+        event_type (DagsterEventType): Filter argument for dagster event type
         asset_key (Optional[AssetKey]): Asset key for which to get asset materialization event
             entries / records.
         asset_partitions (Optional[List[str]]): Filter parameter such that only asset
@@ -166,7 +165,7 @@ class EventRecordsFilter(
 
     def __new__(
         cls,
-        event_type: Optional[DagsterEventType] = None,
+        event_type: DagsterEventType,
         asset_key: Optional[AssetKey] = None,
         asset_partitions: Optional[List[str]] = None,
         after_cursor: Optional[Union[int, RunShardedEventsCursor]] = None,
@@ -175,12 +174,7 @@ def __new__(
         before_timestamp: Optional[float] = None,
     ):
         check.opt_list_param(asset_partitions, "asset_partitions", of_type=str)
-        event_type = check.opt_inst_param(event_type, "event_type", DagsterEventType)
-        if not event_type:
-            warnings.warn(
-                "The use of `EventRecordsFilter` without an event type is deprecated and will "
-                "begin erroring starting in 0.15.0"
-            )
+        check.inst_param(event_type, "event_type", DagsterEventType)
 
         return super(EventRecordsFilter, cls).__new__(
             cls,
@@ -316,7 +310,7 @@ def optimize_for_dagit(self, statement_timeout: int):
     @abstractmethod
     def get_event_records(
         self,
-        event_records_filter: Optional[EventRecordsFilter] = None,
+        event_records_filter: EventRecordsFilter,
         limit: Optional[int] = None,
         ascending: bool = False,
     ) -> Iterable[EventLogRecord]:
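Because the constructor now runs `check.inst_param` on `event_type`, building an `EventRecordsFilter` without one fails immediately rather than emitting the old deprecation warning. A small sketch of the new constructor contract (the asset key and error handling are illustrative):

from dagster import AssetKey, DagsterEventType, EventRecordsFilter

# Valid after this commit: event_type is a required argument.
records_filter = EventRecordsFilter(
    event_type=DagsterEventType.ASSET_MATERIALIZATION,
    asset_key=AssetKey("my_asset"),  # hypothetical asset key
)

# No longer allowed: omitting event_type raises a TypeError, and passing None
# now trips check.inst_param instead of emitting a deprecation warning.
try:
    EventRecordsFilter(event_type=None)
except Exception:
    pass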
32 changes: 7 additions & 25 deletions python_modules/dagster/dagster/core/storage/event_log/in_memory.py
@@ -1,6 +1,5 @@
 import logging
 import time
-import warnings
 from collections import OrderedDict, defaultdict
 from typing import Dict, Iterable, Mapping, Optional, Sequence, Set, cast
 
@@ -184,41 +183,24 @@ def is_persistent(self):
 
     def get_event_records(
         self,
-        event_records_filter: Optional[EventRecordsFilter] = None,
+        event_records_filter: EventRecordsFilter,
         limit: Optional[int] = None,
         ascending: bool = False,
     ) -> Iterable[EventLogRecord]:
-        if not event_records_filter:
-            warnings.warn(
-                "The use of `get_event_records` without an `EventRecordsFilter` is deprecated and "
-                "will begin erroring starting in 0.15.0"
-            )
-
         after_id = (
-            (
-                event_records_filter.after_cursor.id
-                if isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
-                else event_records_filter.after_cursor
-            )
-            if event_records_filter
-            else None
+            event_records_filter.after_cursor.id
+            if isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
+            else event_records_filter.after_cursor
         )
         before_id = (
-            (
-                event_records_filter.before_cursor.id
-                if isinstance(event_records_filter.before_cursor, RunShardedEventsCursor)
-                else event_records_filter.before_cursor
-            )
-            if event_records_filter
-            else None
+            event_records_filter.before_cursor.id
+            if isinstance(event_records_filter.before_cursor, RunShardedEventsCursor)
+            else event_records_filter.before_cursor
         )
 
         filtered_events = []
 
         def _apply_filters(record):
-            if not event_records_filter:
-                return True
-
             if (
                 event_records_filter.event_type
                 and record.dagster_event.event_type_value != event_records_filter.event_type.value
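The in-memory storage above no longer guards every cursor access on the filter being present; it simply unwraps a `RunShardedEventsCursor` to its storage id and passes plain integers through. A standalone sketch of that unwrapping logic, using a stand-in cursor type rather than Dagster's actual class:

from typing import NamedTuple, Optional, Union


class FakeRunShardedCursor(NamedTuple):
    """Stand-in for RunShardedEventsCursor: a storage id plus run metadata."""

    id: int
    run_updated_after: float


def resolve_cursor_id(cursor: Optional[Union[int, FakeRunShardedCursor]]) -> Optional[int]:
    # Mirrors the in-memory storage: a run-sharded cursor is unwrapped to its
    # integer storage id, a plain int passes through, and None means "no bound".
    if isinstance(cursor, FakeRunShardedCursor):
        return cursor.id
    return cursor


assert resolve_cursor_id(None) is None
assert resolve_cursor_id(7) == 7
assert resolve_cursor_id(FakeRunShardedCursor(id=7, run_updated_after=0.0)) == 7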
python_modules/dagster/dagster/core/storage/event_log/sql_event_log.py
@@ -1,5 +1,4 @@
 import logging
-import warnings
 from abc import abstractmethod
 from collections import OrderedDict
 from datetime import datetime
@@ -594,18 +593,13 @@ def enable_secondary_index(self, name):
     def _apply_filter_to_query(
         self,
         query,
-        event_records_filter=None,
+        event_records_filter,
         asset_details=None,
         apply_cursor_filters=True,
     ):
-        if not event_records_filter:
-            return query
-
-        if event_records_filter.event_type:
-            query = query.where(
-                SqlEventLogStorageTable.c.dagster_event_type
-                == event_records_filter.event_type.value
-            )
+        query = query.where(
+            SqlEventLogStorageTable.c.dagster_event_type == event_records_filter.event_type.value
+        )
 
         if event_records_filter.asset_key:
             query = query.where(
@@ -667,23 +661,17 @@ def _apply_filter_to_query(
 
     def get_event_records(
         self,
-        event_records_filter: Optional[EventRecordsFilter] = None,
+        event_records_filter: EventRecordsFilter,
         limit: Optional[int] = None,
         ascending: bool = False,
     ) -> Iterable[EventLogRecord]:
         """Returns a list of (record_id, record)."""
-        check.opt_inst_param(event_records_filter, "event_records_filter", EventRecordsFilter)
+        check.inst_param(event_records_filter, "event_records_filter", EventRecordsFilter)
         check.opt_int_param(limit, "limit")
         check.bool_param(ascending, "ascending")
 
-        if not event_records_filter:
-            warnings.warn(
-                "The use of `get_event_records` without an `EventRecordsFilter` is deprecated and "
-                "will begin erroring starting in 0.15.0"
-            )
-
         query = db.select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
-        if event_records_filter and event_records_filter.asset_key:
+        if event_records_filter.asset_key:
             asset_details = next(iter(self._get_assets_details([event_records_filter.asset_key])))
         else:
             asset_details = None
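With the filter guaranteed to exist and to carry an event type, the SQL storage can apply the event-type predicate unconditionally. A rough, self-contained SQLAlchemy 1.x sketch of that shape, matching the `db.select([...])` style used above; the table is a toy stand-in, not Dagster's real event log schema:

import sqlalchemy as db

metadata = db.MetaData()

# Toy stand-in for SqlEventLogStorageTable with only the columns used here.
event_logs = db.Table(
    "event_logs",
    metadata,
    db.Column("id", db.Integer, primary_key=True),
    db.Column("dagster_event_type", db.String),
    db.Column("event", db.Text),
)


def apply_event_type_filter(query, event_type_value: str):
    # The predicate is applied unconditionally, matching the post-commit code path.
    return query.where(event_logs.c.dagster_event_type == event_type_value)


query = db.select([event_logs.c.id, event_logs.c.event])
query = apply_event_type_filter(query, "ASSET_MATERIALIZATION")
print(query)  # SELECT ... WHERE event_logs.dagster_event_type = :dagster_event_type_1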
python_modules/dagster/dagster/core/storage/event_log/sqlite/sqlite_event_log.py
@@ -246,7 +246,7 @@ def store_event(self, event):
 
     def get_event_records(
         self,
-        event_records_filter: Optional[EventRecordsFilter] = None,
+        event_records_filter: EventRecordsFilter,
         limit: Optional[int] = None,
         ascending: bool = False,
     ) -> Iterable[EventLogRecord]:
@@ -271,15 +271,13 @@ def get_event_records(
         )
 
         query = db.select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
-        if event_records_filter and event_records_filter.asset_key:
+        if event_records_filter.asset_key:
             asset_details = next(iter(self._get_assets_details([event_records_filter.asset_key])))
         else:
             asset_details = None
 
-        if (
-            event_records_filter
-            and event_records_filter.after_cursor != None
-            and not isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
+        if event_records_filter.after_cursor != None and not isinstance(
+            event_records_filter.after_cursor, RunShardedEventsCursor
         ):
             raise Exception(
                 """
@@ -307,8 +305,7 @@ def get_event_records(
         # whose events may qualify the query, and then open run_connection per run_id at a time.
         run_updated_after = (
             event_records_filter.after_cursor.run_updated_after
-            if event_records_filter
-            and isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
+            if isinstance(event_records_filter.after_cursor, RunShardedEventsCursor)
             else None
         )
         run_records = self._instance.get_run_records(
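The run-sharded SQLite storage still rejects plain integer `after_cursor` values, so callers paging its events supply a `RunShardedEventsCursor` carrying both a storage id and a run-updated lower bound. A hedged sketch; the import path and keyword names are assumptions inferred from the fields read above, and the values are illustrative:

from datetime import datetime, timedelta

from dagster import AssetKey, DagsterEventType, EventRecordsFilter

# Assumed import location; this diff only shows the type being referenced.
from dagster.core.storage.event_log.base import RunShardedEventsCursor

# Hypothetical paging state: the last storage id we processed, plus a lower
# bound on run update time so the run-sharded storage can prune which runs to scan.
cursor = RunShardedEventsCursor(
    id=123,
    run_updated_after=datetime.now() - timedelta(days=1),
)

records_filter = EventRecordsFilter(
    event_type=DagsterEventType.ASSET_MATERIALIZATION,
    asset_key=AssetKey("my_asset"),  # hypothetical asset key
    after_cursor=cursor,
)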
python_modules/dagster/dagster_tests/core_tests/storage_tests/utils/event_log_storage.py
@@ -952,7 +952,11 @@ def _solids():
         )
 
         assert asset_key in set(storage.all_asset_keys())
-        _records = storage.get_event_records(EventRecordsFilter(asset_key=asset_key))
+        _records = storage.get_event_records(
+            EventRecordsFilter(
+                event_type=DagsterEventType.ASSET_MATERIALIZATION, asset_key=asset_key
+            )
+        )
         assert len(_logs) == 1
         assert re.match("Could not resolve event record as EventLogEntry", _logs[0])
 
@@ -980,7 +984,11 @@ def _solids():
             )
         )
         assert asset_key in set(storage.all_asset_keys())
-        _records = storage.get_event_records(EventRecordsFilter(asset_key=asset_key))
+        _records = storage.get_event_records(
+            EventRecordsFilter(
+                event_type=DagsterEventType.ASSET_MATERIALIZATION, asset_key=asset_key
+            )
+        )
         assert len(_logs) == 1
         assert re.match("Could not parse event record id", _logs[0])
 
@@ -1718,12 +1726,16 @@ def solid_partitioned(context):
             storage.store_event(event)
 
         records = storage.get_event_records(
-            EventRecordsFilter(asset_key=AssetKey("asset_key"))
+            EventRecordsFilter(
+                event_type=DagsterEventType.ASSET_MATERIALIZATION,
+                asset_key=AssetKey("asset_key"),
+            )
         )
         assert len(records) == 4
 
         records = storage.get_event_records(
             EventRecordsFilter(
+                event_type=DagsterEventType.ASSET_MATERIALIZATION,
                 asset_key=AssetKey("asset_key"),
                 asset_partitions=["partition_a", "partition_b"],
             )
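The updated tests all pass an explicit event type, and the same shape applies to cursor-driven consumers such as sensors that poll for new materializations. A sketch of that usage, with the asset name and cursor bookkeeping purely illustrative:

from dagster import AssetKey, DagsterEventType, DagsterInstance, EventRecordsFilter


def new_materialization_records(instance: DagsterInstance, last_storage_id: int):
    # Fetch records strictly after the last storage id we processed, oldest first.
    records = instance.get_event_records(
        EventRecordsFilter(
            event_type=DagsterEventType.ASSET_MATERIALIZATION,
            asset_key=AssetKey("asset_key"),  # hypothetical asset key
            after_cursor=last_storage_id,
        ),
        ascending=True,
    )
    # Each EventLogRecord carries a storage_id; advance the cursor past what we saw.
    new_cursor = max((record.storage_id for record in records), default=last_storage_id)
    return records, new_cursor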
