double-write selector_id to ticks, jobs, instigators 3/5 (#7191)
* add secondary index table migration to track schedule data migrations

* schema migration stuff

* fix sqlite migration

* add schema for instigators table, keyed by selector

* switch check for migration - fix mysql

* fix mysql backcompat tests to start from clean slate

* fix up comments, saving repository_selector_id

* double-write selector_id to ticks, jobs, instigators

* add repository name

* fix wipe

* fix mysql backcompat tests to start from clean slate

* update from comments, fix migration

* fix imports

* fix backcompat tests
prha committed Apr 1, 2022
1 parent ae93b45 commit 6262a51
Showing 23 changed files with 8,756 additions and 48 deletions.
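
For orientation before the per-file diffs: a selector_id is a stable snapshot hash of the (location, repository, name) triple, and this commit double-writes it alongside the legacy origin id wherever instigator state and ticks are stored. A minimal sketch of the derivation, with illustrative names (that create_snapshot_id hashes the serdes-serialized tuple is an assumption based on how it is used below):

from dagster.core.host_representation import InstigatorSelector
from dagster.serdes import create_snapshot_id

# Stable across processes: depends only on the selector triple, not on
# origin metadata such as the grpc server host or executable path.
selector_id = create_snapshot_id(
    InstigatorSelector("my_location", "my_repo", "my_sensor")
)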
[changed file, path not shown]
@@ -1,6 +1,6 @@
from dagster import check
from dagster.core.definitions.run_request import InstigatorType
from dagster.core.host_representation import InstigationSelector
from dagster.core.host_representation import InstigatorSelector
from dagster.core.scheduler.instigation import InstigatorStatus

from .utils import capture_error
@@ -44,7 +44,7 @@ def get_unloadable_instigator_states_or_error(graphene_info, instigator_type=Non
def get_instigator_state_or_error(graphene_info, selector):
from ..schema.instigation import GrapheneInstigationState

check.inst_param(selector, "selector", InstigationSelector)
check.inst_param(selector, "selector", InstigatorSelector)
location = graphene_info.context.get_repository_location(selector.location_name)
repository = location.get_repository(selector.repository_name)

[changed file, path not shown]
@@ -4,7 +4,7 @@
from dagster.core.definitions.events import AssetKey
from dagster.core.execution.backfill import BulkActionStatus
from dagster.core.host_representation import (
InstigationSelector,
InstigatorSelector,
RepositorySelector,
ScheduleSelector,
SensorSelector,
@@ -340,7 +340,7 @@ def resolve_sensorsOrError(self, graphene_info, **kwargs):

def resolve_instigationStateOrError(self, graphene_info, instigationSelector):
return get_instigator_state_or_error(
graphene_info, InstigationSelector.from_graphql_input(instigationSelector)
graphene_info, InstigatorSelector.from_graphql_input(instigationSelector)
)

def resolve_unloadableInstigationStatesOrError(self, graphene_info, **kwargs):
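
Both GraphQL call sites above now resolve the renamed InstigatorSelector from the same input shape; a sketch of that input with illustrative values (the key names come from from_graphql_input further down in this diff):

from dagster.core.host_representation import InstigatorSelector

selector = InstigatorSelector.from_graphql_input(
    {
        "repositoryLocationName": "my_location",
        "repositoryName": "my_repo",
        "name": "my_schedule",
    }
)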
[changed file, path not shown]
@@ -610,6 +610,7 @@ def test_sensor_ticks_filtered(graphql_context):
instigator_type=InstigatorType.SENSOR,
status=TickStatus.STARTED,
timestamp=now.timestamp(),
selector_id=external_sensor.selector_id,
)
)

@@ -621,6 +622,7 @@ def test_sensor_ticks_filtered(graphql_context):
instigator_type=InstigatorType.SENSOR,
status=TickStatus.SKIPPED,
timestamp=now.timestamp(),
selector_id=external_sensor.selector_id,
)
)

@@ -633,6 +635,7 @@ def test_sensor_ticks_filtered(graphql_context):
status=TickStatus.FAILURE,
timestamp=now.timestamp(),
error=SerializableErrorInfo(message="foobar", stack=[], cls_name=None, cause=None),
selector_id=external_sensor.selector_id,
)
)

[changed file, path not shown]
@@ -55,7 +55,7 @@
from .represented import RepresentedPipeline
from .selector import (
GraphSelector,
InstigationSelector,
InstigatorSelector,
PipelineSelector,
RepositorySelector,
ScheduleSelector,
[changed file, path not shown]
@@ -14,6 +14,7 @@
from dagster.core.origin import PipelinePythonOrigin
from dagster.core.snap import ExecutionPlanSnapshot
from dagster.core.utils import toposort
from dagster.serdes import create_snapshot_id
from dagster.utils.schedules import schedule_execution_time_iterator

from .external_data import (
@@ -27,6 +28,7 @@
from .handle import InstigatorHandle, PartitionSetHandle, PipelineHandle, RepositoryHandle
from .pipeline_index import PipelineIndex
from .represented import RepresentedPipeline
from .selector import InstigatorSelector

if TYPE_CHECKING:
from dagster.core.scheduler.instigation import InstigatorState
@@ -490,6 +492,16 @@ def get_external_origin(self):
def get_external_origin_id(self):
return self.get_external_origin().get_id()

@property
def selector_id(self):
return create_snapshot_id(
InstigatorSelector(
self.handle.location_name,
self.handle.repository_name,
self._external_schedule_data.name,
)
)

@property
def default_status(self) -> DefaultScheduleStatus:
return (
@@ -602,6 +614,16 @@ def get_external_origin(self):
def get_external_origin_id(self):
return self.get_external_origin().get_id()

@property
def selector_id(self):
return create_snapshot_id(
InstigatorSelector(
self.handle.location_name,
self.handle.repository_name,
self._external_sensor_data.name,
)
)

def get_current_instigator_state(self, stored_state: Optional["InstigatorState"]):
from dagster.core.scheduler.instigation import (
InstigatorState,
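
ExternalSchedule and ExternalSensor now derive selector_id from the same InstigatorSelector triple as InstigatorState.selector_id (added in instigation.py below), so the external representation and the stored state agree on the new key. A sketch of that invariant, with hypothetical objects:

# external_sensor: an ExternalSensor; stored_state: the InstigatorState
# persisted for it. Both ids hash InstigatorSelector(location, repo, name).
assert external_sensor.selector_id == stored_state.selector_id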
[changed file, path not shown]
@@ -1,6 +1,7 @@
from typing import List, NamedTuple, Optional

from dagster import check
from dagster.serdes import whitelist_for_serdes


class PipelineSelector(
@@ -53,6 +54,7 @@ def with_solid_selection(self, solid_selection):
)


@whitelist_for_serdes
class RepositorySelector(
NamedTuple("_RepositorySelector", [("location_name", str), ("repository_name", str)])
):
@@ -136,13 +138,14 @@ def from_graphql_input(graphql_data):
)


class InstigationSelector(
@whitelist_for_serdes
class InstigatorSelector(
NamedTuple(
"_InstigationSelector", [("location_name", str), ("repository_name", str), ("name", str)]
"_InstigatorSelector", [("location_name", str), ("repository_name", str), ("name", str)]
)
):
def __new__(cls, location_name: str, repository_name: str, name: str):
return super(InstigationSelector, cls).__new__(
return super(InstigatorSelector, cls).__new__(
cls,
location_name=check.str_param(location_name, "location_name"),
repository_name=check.str_param(repository_name, "repository_name"),
@@ -158,7 +161,7 @@ def to_graphql_input(self):

@staticmethod
def from_graphql_input(graphql_data):
return InstigationSelector(
return InstigatorSelector(
location_name=graphql_data["repositoryLocationName"],
repository_name=graphql_data["repositoryName"],
name=graphql_data["name"],
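
The @whitelist_for_serdes decorators added here are what make these selectors hashable by create_snapshot_id, which serializes the named tuple before hashing. A round-trip sketch (serialize_dagster_namedtuple as the serialization entry point is an assumption; deserialize_as is imported elsewhere in this diff):

from dagster.core.host_representation import InstigatorSelector
from dagster.serdes import deserialize_as, serialize_dagster_namedtuple

s = InstigatorSelector("my_location", "my_repo", "daily_schedule")
# Whitelisting registers the class with serdes, so this round-trips cleanly.
assert deserialize_as(serialize_dagster_namedtuple(s), InstigatorSelector) == s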
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/core/instance/__init__.py
@@ -1776,8 +1776,8 @@ def add_instigator_state(self, state):
def update_instigator_state(self, state):
return self._schedule_storage.update_instigator_state(state)

def delete_instigator_state(self, origin_id):
return self._schedule_storage.delete_instigator_state(origin_id)
def delete_instigator_state(self, origin_id, selector_id):
return self._schedule_storage.delete_instigator_state(origin_id, selector_id)

@property
def supports_batch_tick_queries(self):
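
Call sites of delete_instigator_state now pass both identifiers during the double-write window; a sketch, where instance is a DagsterInstance and state is an InstigatorState (its selector_id property is added in instigation.py below):

instance.delete_instigator_state(state.instigator_origin_id, state.selector_id)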
28 changes: 28 additions & 0 deletions python_modules/dagster/dagster/core/scheduler/instigation.py
@@ -5,6 +5,8 @@
from dagster import check
from dagster.core.definitions.run_request import InstigatorType
from dagster.core.host_representation.origin import ExternalInstigatorOrigin
from dagster.core.host_representation.selector import InstigatorSelector, RepositorySelector
from dagster.serdes import create_snapshot_id
from dagster.serdes.serdes import (
DefaultNamedTupleSerializer,
WhitelistMap,
@@ -197,10 +199,29 @@ def instigator_name(self):
def repository_origin_id(self):
return self.origin.external_repository_origin.get_id()

@property
def repository_selector_id(self):
return create_snapshot_id(
RepositorySelector(
self.origin.external_repository_origin.repository_location_origin.location_name,
self.origin.external_repository_origin.repository_name,
)
)

@property
def instigator_origin_id(self):
return self.origin.get_id()

@property
def selector_id(self):
return create_snapshot_id(
InstigatorSelector(
self.origin.external_repository_origin.repository_location_origin.location_name,
self.origin.external_repository_origin.repository_name,
self.origin.instigator_name,
)
)

def with_status(self, status):
check.inst_param(status, "status", InstigatorStatus)
return InstigatorState(
@@ -311,6 +332,10 @@ def with_origin_run(self, origin_run_id):
def instigator_origin_id(self):
return self.tick_data.instigator_origin_id

@property
def selector_id(self):
return self.tick_data.selector_id

@property
def instigator_name(self):
return self.tick_data.instigator_name
@@ -427,6 +452,7 @@ class TickData(
("cursor", Optional[str]),
("origin_run_ids", List[str]),
("failure_count", int),
("selector_id", Optional[str]),
],
)
):
@@ -464,6 +490,7 @@ def __new__(
cursor: Optional[str] = None,
origin_run_ids: Optional[List[str]] = None,
failure_count: Optional[int] = None,
selector_id: Optional[str] = None,
):
_validate_tick_args(instigator_type, status, run_ids, error, skip_reason)
return super(TickData, cls).__new__(
@@ -480,6 +507,7 @@ def __new__(
cursor=check.opt_str_param(cursor, "cursor"),
origin_run_ids=check.opt_list_param(origin_run_ids, "origin_run_ids", of_type=str),
failure_count=check.opt_int_param(failure_count, "failure_count", 0),
selector_id=check.opt_str_param(selector_id, "selector_id"),
)

def with_status(self, status, error=None, timestamp=None, failure_count=None):
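
Because selector_id defaults to None on TickData, ticks serialized before this commit still deserialize, and pre-migration callers keep working; only new writes carry the key. A construction sketch (field names taken from the surrounding diff and the graphql tests above; values illustrative):

from dagster.core.definitions.run_request import InstigatorType
from dagster.core.scheduler.instigation import TickData, TickStatus

legacy_tick = TickData(
    instigator_origin_id="origin_abc",
    instigator_name="my_sensor",
    instigator_type=InstigatorType.SENSOR,
    status=TickStatus.STARTED,
    timestamp=1648771200.0,
)
assert legacy_tick.selector_id is None  # backfilled later by the data migration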
[changed file, path not shown]
@@ -433,14 +433,14 @@ def _apply_migration(self, migration_name, migration_fn, print_fn, force):
if self.has_secondary_index(migration_name):
if not force:
if print_fn:
print_fn("Skipping already reindexed summary: {}".format(migration_name))
print_fn(f"Skipping already applied data migration: {migration_name}")
return
if print_fn:
print_fn("Starting reindex: {}".format(migration_name))
print_fn(f"Starting data migration: {migration_name}")
migration_fn()(self, print_fn)
self.enable_secondary_index(migration_name)
if print_fn:
print_fn("Finished reindexing: {}".format(migration_name))
print_fn(f"Finished data migration: {migration_name}")

def reindex_events(self, print_fn=None, force=False):
"""Call this method to run any data migrations across the event_log table"""
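
The reworded print_fn messages reflect that these steps are data migrations rather than reindexing passes. A sketch of invoking them with progress output (reindex_events and its signature are shown in the hunk above; storage is a hypothetical instance of this class):

storage.reindex_events(print_fn=print)              # runs pending data migrations
storage.reindex_events(print_fn=print, force=True)  # re-runs already-applied ones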
[changed file, path not shown]
@@ -823,6 +823,8 @@ def _execute_data_migrations(
for migration_name, migration_fn in migrations.items():
if self.has_built_index(migration_name):
if not force_rebuild_all:
if print_fn:
print_fn(f"Skipping already applied data migration: {migration_name}")
continue
if print_fn:
print_fn(f"Starting data migration: {migration_name}")
[changed file, path not shown]
@@ -51,7 +51,7 @@ def update_instigator_state(self, state: InstigatorState):
"""

@abc.abstractmethod
def delete_instigator_state(self, origin_id: str):
def delete_instigator_state(self, origin_id: str, selector_id: str):
"""Delete a state in storage.
Args:
106 changes: 104 additions & 2 deletions python_modules/dagster/dagster/core/storage/schedules/migration.py
@@ -1,4 +1,106 @@
from typing import Callable, Mapping

REQUIRED_SCHEDULE_DATA_MIGRATIONS: Mapping[str, Callable] = {}
OPTIONAL_SCHEDULE_DATA_MIGRATIONS: Mapping[str, Callable] = {}
import sqlalchemy as db
from tqdm import tqdm

from dagster.core.scheduler.instigation import InstigatorState
from dagster.serdes import deserialize_as

from ..schedules.schema import InstigatorsTable, JobTable, JobTickTable

SCHEDULE_JOBS_SELECTOR_ID = "schedule_jobs_selector_id"
SCHEDULE_TICKS_SELECTOR_ID = "schedule_ticks_selector_id"

REQUIRED_SCHEDULE_DATA_MIGRATIONS: Mapping[str, Callable] = {
SCHEDULE_JOBS_SELECTOR_ID: lambda: add_selector_id_to_jobs_table,
}
OPTIONAL_SCHEDULE_DATA_MIGRATIONS: Mapping[str, Callable] = {
SCHEDULE_TICKS_SELECTOR_ID: lambda: add_selector_id_to_ticks_table,
}


def add_selector_id_to_jobs_table(storage, print_fn=None):
"""
Utility method that calculates the selector_id for each stored instigator state, and writes
it to the jobs table.
"""

if print_fn:
print_fn("Querying storage.")

with storage.connect() as conn:
rows = conn.execute(
db.select(
[
JobTable.c.id,
JobTable.c.job_body,
JobTable.c.create_timestamp,
JobTable.c.update_timestamp,
]
).order_by(JobTable.c.id.asc())
).fetchall()

for (row_id, state_str, create_timestamp, update_timestamp) in tqdm(rows):
state = deserialize_as(state_str, InstigatorState)
selector_id = state.selector_id

# insert the state into a new instigator table, which has a unique constraint on
# selector_id
try:
conn.execute(
InstigatorsTable.insert().values(
selector_id=selector_id,
repository_selector_id=state.repository_selector_id,
status=state.status.value,
instigator_type=state.instigator_type.value,
instigator_body=state_str,
create_timestamp=create_timestamp,
update_timestamp=update_timestamp,
)
)
except db.exc.IntegrityError:
conn.execute(
InstigatorsTable.update()
.where(InstigatorsTable.c.selector_id == selector_id)
.values(
status=state.status.value,
repository_selector_id=state.repository_selector_id,
instigator_type=state.instigator_type.value,
instigator_body=state_str,
update_timestamp=update_timestamp,
)
)

conn.execute(
JobTable.update() # pylint: disable=no-value-for-parameter
.where(JobTable.c.id == row_id)
.where(JobTable.c.selector_id == None)
.values(selector_id=state.selector_id)
)

if print_fn:
print_fn("Complete.")


def add_selector_id_to_ticks_table(storage, print_fn=None):
"""
Utility method that calculates the selector_id for each stored instigator state, and writes
it to the ticks table.
"""

if print_fn:
print_fn("Querying storage.")

instigator_states = storage.all_instigator_state()
for state in tqdm(instigator_states):

with storage.connect() as conn:
conn.execute(
JobTickTable.update() # pylint: disable=no-value-for-parameter
.where(JobTickTable.c.job_origin_id == state.instigator_origin_id)
.where(JobTickTable.c.selector_id == None)
.values(selector_id=state.selector_id)
)

if print_fn:
print_fn("Complete.")
