add core partition status storage query (#7662)
* core partition status storage query

* fix

* add instance method

* fix instance call

* add repository_label argument to get_run_partition_data storage call

* fix test
prha committed May 3, 2022
1 parent c55e3f4 commit de0eae9
Showing 6 changed files with 205 additions and 3 deletions.
10 changes: 10 additions & 0 deletions python_modules/dagster/dagster/core/instance/__init__.py
@@ -49,6 +49,7 @@
PipelineRun,
PipelineRunStatsSnapshot,
PipelineRunStatus,
RunPartitionData,
RunRecord,
RunsFilter,
TagBucket,
@@ -1247,6 +1248,15 @@ def get_run_records(
def supports_bucket_queries(self):
return self._run_storage.supports_bucket_queries

@traced
def get_run_partition_data(
self, partition_set_name: str, job_name: str, repository_label: str
) -> List[RunPartitionData]:
"""Get run partition data for a given partitioned job."""
return self._run_storage.get_run_partition_data(
partition_set_name, job_name, repository_label
)

def wipe(self):
self._run_storage.wipe()
self._event_storage.wipe()
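A minimal usage sketch of the new instance method; the partition set, job, and repository label values below are illustrative, not taken from this commit:

    from dagster import DagsterInstance

    instance = DagsterInstance.get()
    # Returns one RunPartitionData record per partition for the given partitioned job.
    partition_data = instance.get_run_partition_data(
        partition_set_name="my_daily_partition_set",  # hypothetical partition set name
        job_name="my_daily_job",  # hypothetical job name
        repository_label="my_repo@my_code_location",  # hypothetical repository label
    )
    for entry in partition_data:
        print(entry.partition, entry.status, entry.run_id)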
31 changes: 31 additions & 0 deletions python_modules/dagster/dagster/core/storage/pipeline_run.py
@@ -595,6 +595,37 @@ def __new__(
)


@whitelist_for_serdes
class RunPartitionData(
NamedTuple(
"_RunPartitionData",
[
("run_id", str),
("partition", str),
("status", DagsterRunStatus),
("start_time", Optional[float]),
("end_time", Optional[float]),
],
)
):
def __new__(
cls,
run_id: str,
partition: str,
status: DagsterRunStatus,
start_time: Optional[float],
end_time: Optional[float],
):
return super(RunPartitionData, cls).__new__(
cls,
run_id=check.str_param(run_id, "run_id"),
partition=check.str_param(partition, "partition"),
status=check.inst_param(status, "status", DagsterRunStatus),
start_time=check.opt_inst(start_time, float),
end_time=check.opt_inst(end_time, float),
)


###################################################################################################
# GRAVEYARD
#
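A small illustration of constructing the new serdes-whitelisted RunPartitionData record; the run id, partition key, and timestamps are made up for the example:

    from dagster.core.storage.pipeline_run import DagsterRunStatus, RunPartitionData

    data = RunPartitionData(
        run_id="8c1b2a7e-6e8f-4d2b-9c1a-1f2e3d4c5b6a",  # hypothetical run id
        partition="2022-05-03",  # hypothetical partition key
        status=DagsterRunStatus.SUCCESS,
        start_time=1651536000.0,  # epoch seconds; may be None when stats are unavailable
        end_time=1651536060.0,
    )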
10 changes: 10 additions & 0 deletions python_modules/dagster/dagster/core/storage/runs/base.py
@@ -8,6 +8,7 @@
from dagster.core.storage.pipeline_run import (
JobBucket,
PipelineRun,
RunPartitionData,
RunRecord,
RunsFilter,
TagBucket,
@@ -322,6 +323,15 @@ def delete_run(self, run_id: str):
def supports_bucket_queries(self):
return True

@abstractmethod
def get_run_partition_data(
self,
partition_set_name: str,
job_name: str,
repository_label: str,
) -> List[RunPartitionData]:
"""Get run partition data for a given partitioned job."""

def migrate(self, print_fn: Optional[Callable] = None, force_rebuild_all: bool = False):
"""Call this method to run any required data migrations"""

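Because get_run_partition_data is declared abstract on RunStorage, any out-of-tree storage implementation now has to provide it. A rough sketch of what an override could look like; the class and its _query_partition_data helper are hypothetical, and the other abstract methods are elided:

    from typing import List

    from dagster.core.storage.pipeline_run import RunPartitionData
    from dagster.core.storage.runs.base import RunStorage


    class MyRunStorage(RunStorage):
        # ... other RunStorage abstract methods omitted for brevity ...

        def get_run_partition_data(
            self, partition_set_name: str, job_name: str, repository_label: str
        ) -> List[RunPartitionData]:
            # Find runs tagged with this partition set, keep the most recent run
            # per partition, and return one RunPartitionData per partition.
            return self._query_partition_data(partition_set_name, job_name)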
37 changes: 36 additions & 1 deletion python_modules/dagster/dagster/core/storage/runs/in_memory.py
@@ -15,10 +15,18 @@
create_execution_plan_snapshot_id,
create_pipeline_snapshot_id,
)
from dagster.core.storage.tags import PARTITION_NAME_TAG, PARTITION_SET_TAG
from dagster.daemon.types import DaemonHeartbeat
from dagster.utils import EPOCH, frozendict, merge_dicts

from ..pipeline_run import JobBucket, PipelineRun, RunRecord, RunsFilter, TagBucket
from ..pipeline_run import (
JobBucket,
PipelineRun,
RunPartitionData,
RunRecord,
RunsFilter,
TagBucket,
)
from .base import RunStorage


@@ -330,6 +338,33 @@ def get_run_groups(
for root_run_id, run_group in root_run_id_to_group.items()
}

def get_run_partition_data(
self, partition_set_name: str, job_name: str, repository_label: str
) -> List[RunPartitionData]:
"""Get run partition data for a given partitioned job."""
check.str_param(partition_set_name, "partition_set_name")
check.str_param(job_name, "job_name")

run_filter = build_run_filter(
RunsFilter(pipeline_name=job_name, tags={PARTITION_SET_TAG: partition_set_name})
)
matching_runs = list(filter(run_filter, list(self._runs.values())[::-1]))
_partition_data_by_partition = {}
for run in matching_runs:
partition = run.tags.get(PARTITION_NAME_TAG)
if not partition or partition in _partition_data_by_partition:
continue

_partition_data_by_partition[partition] = RunPartitionData(
run_id=run.run_id,
partition=partition,
status=run.status,
start_time=None,
end_time=None,
)

return list(_partition_data_by_partition.values())

# Daemon Heartbeats

def add_daemon_heartbeat(self, daemon_heartbeat: DaemonHeartbeat):
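The in-memory implementation walks runs from most recent to oldest and keeps the first record it sees per partition, so a retried partition reports its latest run. A brief usage sketch against an ephemeral storage (run construction elided; the test added at the bottom of this commit shows the full setup):

    from dagster.core.storage.runs import InMemoryRunStorage

    storage = InMemoryRunStorage()
    # ... add runs tagged with PARTITION_SET_TAG and PARTITION_NAME_TAG ...
    data = storage.get_run_partition_data("foo_set", "foo_pipeline", "fake@fake")
    latest_run_by_partition = {d.partition: d.run_id for d in data}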
@@ -35,7 +35,15 @@
from dagster.seven import JSONDecodeError
from dagster.utils import merge_dicts, utc_datetime_from_timestamp

from ..pipeline_run import JobBucket, PipelineRun, RunRecord, RunsFilter, TagBucket
from ..pipeline_run import (
DagsterRunStatus,
JobBucket,
PipelineRun,
RunPartitionData,
RunRecord,
RunsFilter,
TagBucket,
)
from .base import RunStorage
from .migration import OPTIONAL_DATA_MIGRATIONS, REQUIRED_DATA_MIGRATIONS, RUN_PARTITIONS
from .schema import (
@@ -792,6 +800,63 @@ def _get_snapshot(self, snapshot_id: str):

return defensively_unpack_pipeline_snapshot_query(logging, row) if row else None

def get_run_partition_data(
self, partition_set_name: str, job_name: str, repository_label: str
) -> List[RunPartitionData]:
if self.has_built_index(RUN_PARTITIONS) and self.has_run_stats_index_cols():
query = self._runs_query(
filters=RunsFilter(
pipeline_name=job_name,
tags={
PARTITION_SET_TAG: partition_set_name,
},
),
columns=["run_id", "status", "start_time", "end_time", "partition"],
)
rows = self.fetchall(query)

# dedup by partition
_partition_data_by_partition = {}
for row in rows:
if not row["partition"] or row["partition"] in _partition_data_by_partition:
continue

_partition_data_by_partition[row["partition"]] = RunPartitionData(
run_id=row["run_id"],
partition=row["partition"],
status=DagsterRunStatus[row["status"]],
start_time=row["start_time"],
end_time=row["end_time"],
)

return list(_partition_data_by_partition.values())
else:
query = self._runs_query(
filters=RunsFilter(
pipeline_name=job_name,
tags={
PARTITION_SET_TAG: partition_set_name,
},
),
)
rows = self.fetchall(query)
_partition_data_by_partition = {}
for row in rows:
run = self._row_to_run(row)
partition = run.tags.get(PARTITION_NAME_TAG)
if not partition or partition in _partition_data_by_partition:
continue

_partition_data_by_partition[partition] = RunPartitionData(
run_id=run.run_id,
partition=partition,
status=run.status,
start_time=None,
end_time=None,
)

return list(_partition_data_by_partition.values())

def _get_partition_runs(
self, partition_set_name: str, partition_name: str
) -> List[PipelineRun]:
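The SQL implementation has two paths: when the RUN_PARTITIONS data migration has been applied and the run-stats index columns exist, it reads partition, status, and timing data directly from indexed columns; otherwise it falls back to deserializing each run body and reading the partition tags, leaving start_time and end_time as None. A small helper mirroring the condition used above (the function name is for illustration only):

    from dagster.core.storage.runs.migration import RUN_PARTITIONS

    def uses_indexed_partition_path(run_storage) -> bool:
        # True when the fast path (indexed partition/status/timing columns) applies.
        return run_storage.has_built_index(RUN_PARTITIONS) and run_storage.has_run_stats_index_cols()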
@@ -34,7 +34,12 @@
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs.migration import REQUIRED_DATA_MIGRATIONS
from dagster.core.storage.runs.sql_run_storage import SqlRunStorage
from dagster.core.storage.tags import PARENT_RUN_ID_TAG, ROOT_RUN_ID_TAG
from dagster.core.storage.tags import (
PARENT_RUN_ID_TAG,
PARTITION_NAME_TAG,
PARTITION_SET_TAG,
ROOT_RUN_ID_TAG,
)
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.core.utils import make_new_run_id
from dagster.daemon.daemon import SensorDaemon
@@ -1020,6 +1025,52 @@ def test_fetch_run_groups_ordering(self, storage):
assert first_root_run.run_id in run_groups
assert second_root_run.run_id not in run_groups

def test_partition_status(self, storage):
one = TestRunStorage.build_run(
run_id=make_new_run_id(),
pipeline_name="foo_pipeline",
status=PipelineRunStatus.FAILURE,
tags={
PARTITION_NAME_TAG: "one",
PARTITION_SET_TAG: "foo_set",
},
)
storage.add_run(one)
two = TestRunStorage.build_run(
run_id=make_new_run_id(),
pipeline_name="foo_pipeline",
status=PipelineRunStatus.FAILURE,
tags={
PARTITION_NAME_TAG: "two",
PARTITION_SET_TAG: "foo_set",
},
)
storage.add_run(two)
two_retried = TestRunStorage.build_run(
run_id=make_new_run_id(),
pipeline_name="foo_pipeline",
status=PipelineRunStatus.SUCCESS,
tags={
PARTITION_NAME_TAG: "two",
PARTITION_SET_TAG: "foo_set",
},
)
storage.add_run(two_retried)
three = TestRunStorage.build_run(
run_id=make_new_run_id(),
pipeline_name="foo_pipeline",
status=PipelineRunStatus.SUCCESS,
tags={
PARTITION_NAME_TAG: "three",
PARTITION_SET_TAG: "foo_set",
},
)
storage.add_run(three)
partition_data = storage.get_run_partition_data("foo_set", "foo_pipeline", "fake@fake")
assert len(partition_data) == 3
assert {_.partition for _ in partition_data} == {"one", "two", "three"}
assert {_.run_id for _ in partition_data} == {one.run_id, two_retried.run_id, three.run_id}
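Note that partition "two" resolves to two_retried.run_id rather than two.run_id: the storage deduplicates by partition and reports the most recent run for each partition.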

def _skip_in_memory(self, storage):
from dagster.core.storage.runs import InMemoryRunStorage

