Skip to content

Commit

Permalink
Attempt to speed up query for existing run keys for a given sensor (#…
Browse files Browse the repository at this point in the history
…8329)

Summary:
DB profiling suggests that the existing query here is quite expensive. Instead of doing a join across sensor name

The way that this could go wrong is if a user is sharing run key formats across lots of sensors - but I think that is less likely to happen than the current situation, and there are mitigations we can apply if we observe that.
  • Loading branch information
gibsondan committed Jun 11, 2022
1 parent e97cc73 commit d46e5d8
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 41 deletions.
19 changes: 15 additions & 4 deletions python_modules/dagster/dagster/core/storage/pipeline_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import warnings
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Mapping, NamedTuple, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
List,
Mapping,
NamedTuple,
Optional,
Type,
Union,
)

import dagster._check as check
from dagster.core.definitions.events import AssetKey
Expand Down Expand Up @@ -510,7 +521,7 @@ class RunsFilter(
("run_ids", List[str]),
("job_name", Optional[str]),
("statuses", List[PipelineRunStatus]),
("tags", Dict[str, str]),
("tags", Dict[str, Union[str, List[str]]]),
("snapshot_id", Optional[str]),
("updated_after", Optional[datetime]),
("mode", Optional[str]),
Expand All @@ -523,7 +534,7 @@ def __new__(
run_ids: Optional[List[str]] = None,
job_name: Optional[str] = None,
statuses: Optional[List[PipelineRunStatus]] = None,
tags: Optional[Dict[str, str]] = None,
tags: Optional[Dict[str, Union[str, List[str]]]] = None,
snapshot_id: Optional[str] = None,
updated_after: Optional[datetime] = None,
mode: Optional[str] = None,
Expand All @@ -539,7 +550,7 @@ def __new__(
run_ids=check.opt_list_param(run_ids, "run_ids", of_type=str),
job_name=check.opt_str_param(job_name, "job_name"),
statuses=check.opt_list_param(statuses, "statuses", of_type=PipelineRunStatus),
tags=check.opt_dict_param(tags, "tags", key_type=str, value_type=str),
tags=check.opt_dict_param(tags, "tags", key_type=str),
snapshot_id=check.opt_str_param(snapshot_id, "snapshot_id"),
updated_after=check.opt_inst_param(updated_after, "updated_after", datetime),
mode=check.opt_str_param(mode, "mode"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def _filter(run: PipelineRun) -> bool:
return False

if filters.tags and not all(
run.tags.get(key) == value for key, value in filters.tags.items()
(run.tags.get(key) == value if isinstance(value, str) else run.tags.get(key) in value)
for key, value in filters.tags.items()
):
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,14 @@ def _add_filters_to_query(self, query, filters: RunsFilter):
query = query.where(
db.or_(
*(
db.and_(RunTagsTable.c.key == key, RunTagsTable.c.value == value)
db.and_(
RunTagsTable.c.key == key,
(
RunTagsTable.c.value == value
if isinstance(value, str)
else RunTagsTable.c.value.in_(value)
),
)
for key, value in filters.tags.items()
)
)
Expand Down
44 changes: 12 additions & 32 deletions python_modules/dagster/dagster/daemon/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
TickData,
TickStatus,
)
from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus, RunsFilter, TagBucket
from dagster.core.storage.tags import RUN_KEY_TAG
from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus, RunsFilter
from dagster.core.storage.tags import RUN_KEY_TAG, SENSOR_NAME_TAG
from dagster.core.telemetry import SENSOR_RUN_CREATED, hash_name, log_action
from dagster.core.workspace import IWorkspace
from dagster.utils import merge_dicts
Expand Down Expand Up @@ -512,38 +512,18 @@ def _fetch_existing_runs(instance, external_sensor, run_requests):
if not run_keys:
return {}

existing_runs = {}
runs_with_run_keys = instance.get_runs(filters=RunsFilter(tags={RUN_KEY_TAG: run_keys}))

if instance.supports_bucket_queries:
runs = instance.get_runs(
filters=RunsFilter(
tags=PipelineRun.tags_for_sensor(external_sensor),
),
bucket_by=TagBucket(
tag_key=RUN_KEY_TAG,
bucket_limit=1,
tag_values=run_keys,
),
)
for run in runs:
tags = run.tags or {}
run_key = tags.get(RUN_KEY_TAG)
existing_runs[run_key] = run
return existing_runs
valid_runs = [
run for run in runs_with_run_keys if run.tags.get(SENSOR_NAME_TAG) == external_sensor.name
]

existing_runs = {}
for run in valid_runs:
tags = run.tags or {}
run_key = tags.get(RUN_KEY_TAG)
existing_runs[run_key] = run

else:
for run_key in run_keys:
runs = instance.get_runs(
filters=RunsFilter(
tags=merge_dicts(
PipelineRun.tags_for_sensor(external_sensor),
{RUN_KEY_TAG: run_key},
)
),
limit=1,
)
if runs:
existing_runs[run_key] = runs[0]
return existing_runs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def test_fetch_by_filter(self, storage):
one = make_new_run_id()
two = make_new_run_id()
three = make_new_run_id()
four = make_new_run_id()

storage.add_run(
TestRunStorage.build_run(
Expand All @@ -360,7 +361,16 @@ def test_fetch_by_filter(self, storage):
)
)

assert len(storage.get_runs()) == 3
storage.add_run(
TestRunStorage.build_run(
run_id=four,
pipeline_name="some_other_pipeline",
tags={"tag": "goodbye"},
status=PipelineRunStatus.FAILURE,
),
)

assert len(storage.get_runs()) == 4

some_runs = storage.get_runs(RunsFilter(run_ids=[one]))
count = storage.get_runs_count(RunsFilter(run_ids=[one]))
Expand Down Expand Up @@ -406,6 +416,19 @@ def test_fetch_by_filter(self, storage):
assert some_runs[0].run_id == two
assert some_runs[1].run_id == one

runs_with_multiple_tag_values = storage.get_runs(
RunsFilter(tags={"tag": ["hello", "goodbye", "farewell"]})
)
assert len(runs_with_multiple_tag_values) == 3
assert runs_with_multiple_tag_values[0].run_id == four
assert runs_with_multiple_tag_values[1].run_id == two
assert runs_with_multiple_tag_values[2].run_id == one

count_with_multiple_tag_values = storage.get_runs_count(
RunsFilter(tags={"tag": ["hello", "goodbye", "farewell"]})
)
assert count_with_multiple_tag_values == 3

some_runs = storage.get_runs(
RunsFilter(
pipeline_name="some_pipeline",
Expand Down Expand Up @@ -447,8 +470,8 @@ def test_fetch_by_filter(self, storage):

some_runs = storage.get_runs(RunsFilter())
count = storage.get_runs_count(RunsFilter())
assert len(some_runs) == 3
assert count == 3
assert len(some_runs) == 4
assert count == 4

def test_fetch_count_by_tag(self, storage):
assert storage
Expand Down

0 comments on commit d46e5d8

Please sign in to comment.