Skip to content

Commit

Permalink
fix tag batching with tag-based filters (#7055)
Browse files Browse the repository at this point in the history
* fix tag batching with tag-based filters

* black; fix tests

* reorganize for readability

* simplify query
  • Loading branch information
prha committed Mar 15, 2022
1 parent 8a611ae commit 7a78835
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 61 deletions.
140 changes: 86 additions & 54 deletions python_modules/dagster/dagster/core/storage/runs/sql_run_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,6 @@ def _add_filters_to_query(self, query, filters: RunsFilter):

return query

def _bucket_rank_column(self, bucket_by, order_by, ascending):
    """Build a window-function column ranking runs within each bucket.

    Runs are partitioned by job name (JobBucket) or by tag value (TagBucket)
    and ranked along the requested sort column/direction.  The column is
    labeled ``rank`` so callers can apply per-bucket limits against it.
    """
    check.inst_param(bucket_by, "bucket_by", (JobBucket, TagBucket))
    check.invariant(
        self.supports_bucket_queries, "Bucket queries are not supported by this storage layer"
    )
    # Default sort key falls back to the runs table id column.
    if order_by:
        sort_col = getattr(RunsTable.c, order_by)
    else:
        sort_col = RunsTable.c.id
    order_fn = db.asc if ascending else db.desc
    # Partition by job name for job buckets, by the tag's value for tag buckets.
    if isinstance(bucket_by, JobBucket):
        partition_col = RunsTable.c.pipeline_name
    else:
        partition_col = RunTagsTable.c.value
    ranked = db.func.rank().over(
        order_by=order_fn(sort_col), partition_by=partition_col
    )
    return ranked.label("rank")

def _runs_query(
self,
filters: Optional[RunsFilter] = None,
Expand All @@ -276,54 +260,102 @@ def _runs_query(
if columns is None:
columns = ["run_body"]

query_columns = [getattr(RunsTable.c, column) for column in columns]

if bucket_by:
if limit or cursor:
check.failed("cannot specify bucket_by and limit/cursor at the same time")
return self._bucketed_runs_query(bucket_by, filters, columns, order_by, ascending)

# this is a bucketed query, so we need to calculate rank to apply bucket-based limits
# and ordering
query_columns.append(self._bucket_rank_column(bucket_by, order_by, ascending))

if isinstance(bucket_by, JobBucket):
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
if filters.tags
else RunsTable
)
.where(RunsTable.c.pipeline_name.in_(bucket_by.job_names))
)
else:
check.invariant(isinstance(bucket_by, TagBucket))
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
)
.where(RunTagsTable.c.key == bucket_by.tag_key)
.where(RunTagsTable.c.value.in_(bucket_by.tag_values))
query_columns = [getattr(RunsTable.c, column) for column in columns]
if filters.tags:
base_query = db.select(query_columns).select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
)
else:
base_query = db.select(query_columns).select_from(RunsTable)

base_query = self._add_filters_to_query(base_query, filters)
return self._add_cursor_limit_to_query(base_query, cursor, limit, order_by, ascending)

def _bucket_rank_column(self, bucket_by, order_by, ascending):
    """Return a SQL window-function column that ranks runs within each bucket.

    Rows are partitioned by job name (``JobBucket``) or tag value
    (``TagBucket``) and ordered by ``order_by``/``ascending``; the result is
    labeled ``rank`` so bucket-limit filters can reference it.
    """
    check.inst_param(bucket_by, "bucket_by", (JobBucket, TagBucket))
    check.invariant(
        self.supports_bucket_queries, "Bucket queries are not supported by this storage layer"
    )
    # default sort key is the runs table id column when no order_by is given
    sorting_column = getattr(RunsTable.c, order_by) if order_by else RunsTable.c.id
    direction = db.asc if ascending else db.desc
    # partition by job name for job buckets, by the tag's value for tag buckets
    bucket_column = (
        RunsTable.c.pipeline_name if isinstance(bucket_by, JobBucket) else RunTagsTable.c.value
    )
    return (
        db.func.rank()
        .over(order_by=direction(sorting_column), partition_by=bucket_column)
        .label("rank")
    )

def _bucketed_runs_query(
    self,
    bucket_by: Union[JobBucket, TagBucket],
    filters: RunsFilter,
    columns: List[str],
    order_by: Optional[str] = None,
    ascending: bool = False,
):
    """Build a query returning at most ``bucket_by.bucket_limit`` runs per bucket.

    A window-function ``rank`` column partitions runs by job name (JobBucket)
    or tag value (TagBucket); the outer select keeps only rows whose rank is
    within the bucket limit.  When both a TagBucket and tag-based filters are
    present, the filters are applied through a separate self-join so that the
    filter tags do not collide with the bucketing tag rows.

    NOTE(review): the original span contained interleaved diff residue (stale
    branches and a duplicated epilogue); this is the reconstructed coherent
    version — confirm against the upstream file.
    """
    bucket_rank = self._bucket_rank_column(bucket_by, order_by, ascending)
    query_columns = [getattr(RunsTable.c, column) for column in columns] + [bucket_rank]

    if isinstance(bucket_by, JobBucket):
        # bucketing by job: join the tags table only when tag filters need it
        base_query = (
            db.select(query_columns)
            .select_from(
                RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
                if filters.tags
                else RunsTable
            )
            .where(RunsTable.c.pipeline_name.in_(bucket_by.job_names))
        )
        base_query = self._add_filters_to_query(base_query, filters)
    elif not filters.tags:
        # bucketing by tag, no tag filters: a single join suffices
        base_query = (
            db.select(query_columns)
            .select_from(
                RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
            )
            .where(RunTagsTable.c.key == bucket_by.tag_key)
            .where(RunTagsTable.c.value.in_(bucket_by.tag_values))
        )
        base_query = self._add_filters_to_query(base_query, filters)
    else:
        # there are tag filters as well as tag buckets, so we have to apply the tag filters in
        # a separate join
        filtered_query = db.select([RunsTable.c.run_id]).select_from(
            RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
        )
        filtered_query = self._add_filters_to_query(filtered_query, filters)
        filtered_query = filtered_query.alias("filtered_query")

        base_query = (
            db.select(query_columns)
            .select_from(
                RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id).join(
                    filtered_query, RunsTable.c.run_id == filtered_query.c.run_id
                )
            )
            .where(RunTagsTable.c.key == bucket_by.tag_key)
            .where(RunTagsTable.c.value.in_(bucket_by.tag_values))
        )

    subquery = base_query.alias("subquery")

    # select all the columns, but skip the bucket_rank column, which is only used for applying
    # the limit / order
    subquery_columns = [getattr(subquery.c, column) for column in columns]
    query = db.select(subquery_columns).order_by(subquery.c.rank.asc())
    if bucket_by.bucket_limit:
        query = query.where(subquery.c.rank <= bucket_by.bucket_limit)

    return query

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,13 +1285,13 @@ def _add_run(job_name, tags=None):
)
)

_one = _add_run("a", tags={"a": "1"})
_two = _add_run("a", tags={"a": "2"})
three = _add_run("a", tags={"a": "3"})
_none = _add_run("a")
b = _add_run("b", tags={"a": "4"})
one = _add_run("a", tags={"a": "1"})
two = _add_run("a", tags={"a": "2"})
_one = _add_run("a", tags={"a": "1", "b": "1"})
_two = _add_run("a", tags={"a": "2", "b": "1"})
three = _add_run("a", tags={"a": "3", "b": "1"})
_none = _add_run("a", tags={"b": "1"})
b = _add_run("b", tags={"a": "4", "b": "2"})
one = _add_run("a", tags={"a": "1", "b": "1"})
two = _add_run("a", tags={"a": "2", "b": "1"})

runs_by_tag = {
run.tags.get("a"): run
Expand All @@ -1305,6 +1305,7 @@ def _add_run(job_name, tags=None):
assert runs_by_tag.get("3").run_id == three.run_id
assert runs_by_tag.get("4").run_id == b.run_id

# fetch with a pipeline_name filter applied
runs_by_tag = {
run.tags.get("a"): run
for run in storage.get_runs(
Expand All @@ -1317,6 +1318,19 @@ def _add_run(job_name, tags=None):
assert runs_by_tag.get("2").run_id == two.run_id
assert runs_by_tag.get("3").run_id == three.run_id

# fetch with a tags filter applied
runs_by_tag = {
run.tags.get("a"): run
for run in storage.get_runs(
filters=RunsFilter(tags={"b": "1"}),
bucket_by=TagBucket(tag_key="a", tag_values=["1", "2", "3", "4"], bucket_limit=1),
)
}
assert set(runs_by_tag.keys()) == {"1", "2", "3"}
assert runs_by_tag.get("1").run_id == one.run_id
assert runs_by_tag.get("2").run_id == two.run_id
assert runs_by_tag.get("3").run_id == three.run_id

def test_run_record_timestamps(self, storage):
assert storage

Expand Down

0 comments on commit 7a78835

Please sign in to comment.