add data migration to bulk actions table for backfill jobs (#8153)
* add data migration, write path for new bulk action columns

* extract bulk actions type

* fix typo
prha committed Jun 8, 2022
1 parent 872c2a8 commit d759550
Showing 7 changed files with 172 additions and 13 deletions.
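The practical effect of the new selector_id and action_type columns is that bulk action rows can be filtered directly in SQL instead of deserializing every stored body. A rough sketch of such a query (not part of this commit), assuming conn is a SQLAlchemy connection obtained from run storage; table and column names come from the diff below:

import sqlalchemy as db

from dagster.core.storage.runs.schema import BulkActionsTable


def backfill_rows_for_selector(conn, selector_id):
    # Filter on the new columns; `key`, `status`, and `id` already exist on the table.
    query = (
        db.select([BulkActionsTable.c.key, BulkActionsTable.c.status])
        .where(BulkActionsTable.c.selector_id == selector_id)
        .where(BulkActionsTable.c.action_type == "PARTITION_BACKFILL")
        .order_by(db.asc(BulkActionsTable.c.id))
    )
    return conn.execute(query).fetchall()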
4 changes: 4 additions & 0 deletions python_modules/dagster/dagster/core/execution/backfill.py
@@ -89,6 +89,10 @@ def __new__(
check.opt_inst_param(error, "error", SerializableErrorInfo),
)

@property
def selector_id(self):
return self.partition_set_origin.get_selector_id()

def with_status(self, status):
check.inst_param(status, "status", BulkActionStatus)
return PartitionBackfill(
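With the selector_id property in place, backfills that target the same partition set can be related without re-deriving ids from their origins. A minimal sketch, assuming instance is a DagsterInstance (get_backfills is exercised in the test at the bottom of this commit):

from collections import defaultdict


def backfills_by_selector(instance):
    # Group every stored backfill by the partition set it targets.
    grouped = defaultdict(list)
    for backfill in instance.get_backfills():
        grouped[backfill.selector_id].append(backfill)
    return dict(grouped)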
8 changes: 8 additions & 0 deletions python_modules/dagster/dagster/core/execution/bulk_actions.py
@@ -0,0 +1,8 @@
from enum import Enum

from dagster.serdes import whitelist_for_serdes


@whitelist_for_serdes
class BulkActionType(Enum):
PARTITION_BACKFILL = "PARTITION_BACKFILL"
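A small illustration of how the enum relates to the new action_type column: the member's string value is what gets stored, and the stored string maps back to the member. Only one member exists today, but the enum leaves room for future bulk action types.

from dagster.core.execution.bulk_actions import BulkActionType

stored = BulkActionType.PARTITION_BACKFILL.value  # "PARTITION_BACKFILL", written to bulk_actions.action_type
assert BulkActionType(stored) is BulkActionType.PARTITION_BACKFILL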
11 changes: 10 additions & 1 deletion python_modules/dagster/dagster/core/host_representation/origin.py
@@ -30,7 +30,7 @@
)
from dagster.serdes.serdes import WhitelistMap, unpack_inner_value

from .selector import RepositorySelector
from .selector import PartitionSetSelector, RepositorySelector

if TYPE_CHECKING:
from dagster.core.host_representation.repository_location import (
@@ -529,3 +529,12 @@ def __new__(cls, external_repository_origin: ExternalRepositoryOrigin, partition

def get_id(self) -> str:
return create_snapshot_id(self)

def get_selector_id(self) -> str:
return create_snapshot_id(
PartitionSetSelector(
self.external_repository_origin.repository_location_origin.location_name,
self.external_repository_origin.repository_name,
self.partition_set_name,
)
)
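Because get_selector_id hashes only the (location_name, repository_name, partition_set_name) triple, two origins that point at the same partition set resolve to the same id no matter when or where they were constructed. A sketch reusing the fake origin from the test at the bottom of this commit, and assuming the gRPC location name is derived deterministically from host and port (the default behavior):

from dagster.core.host_representation.origin import (
    ExternalPartitionSetOrigin,
    ExternalRepositoryOrigin,
    GrpcServerRepositoryLocationOrigin,
)


def make_origin():
    # Re-create the same logical partition set origin from scratch.
    return ExternalPartitionSetOrigin(
        external_repository_origin=ExternalRepositoryOrigin(
            repository_location_origin=GrpcServerRepositoryLocationOrigin(
                port=1234, host="localhost"
            ),
            repository_name="fake_repository",
        ),
        partition_set_name="fake",
    )


assert make_origin().get_selector_id() == make_origin().get_selector_id()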
27 changes: 27 additions & 0 deletions python_modules/dagster/dagster/core/host_representation/selector.py
@@ -201,3 +201,30 @@ def to_graphql_input(self):
"repositoryName": self.repository_name,
"graphName": self.graph_name,
}


@whitelist_for_serdes
class PartitionSetSelector(
NamedTuple(
"_PartitionSetSelector",
[("location_name", str), ("repository_name", str), ("partition_set_name", str)],
)
):
"""
The information needed to resolve a partition set within a host process.
"""

def __new__(cls, location_name: str, repository_name: str, partition_set_name: str):
return super(PartitionSetSelector, cls).__new__(
cls,
location_name=check.str_param(location_name, "location_name"),
repository_name=check.str_param(repository_name, "repository_name"),
partition_set_name=check.str_param(partition_set_name, "partition_set_name"),
)

def to_graphql_input(self):
return {
"repositoryLocationName": self.location_name,
"repositoryName": self.repository_name,
"partitionSetName": self.partition_set_name,
}
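For reference, the GraphQL input shape produced by the new selector, using hypothetical location, repository, and partition set names (the module path assumes selector.py sits alongside origin.py in dagster.core.host_representation, as the relative import above suggests):

from dagster.core.host_representation.selector import PartitionSetSelector

selector = PartitionSetSelector("my_location", "my_repo", "my_partition_set")
assert selector.to_graphql_input() == {
    "repositoryLocationName": "my_location",
    "repositoryName": "my_repo",
    "partitionSetName": "my_partition_set",
}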
60 changes: 54 additions & 6 deletions python_modules/dagster/dagster/core/storage/runs/migration.py
@@ -6,26 +6,30 @@
import dagster._check as check
from dagster.serdes import deserialize_as

from ...execution.backfill import PartitionBackfill
from ...execution.bulk_actions import BulkActionType
from ..pipeline_run import PipelineRun, PipelineRunStatus
from ..runs.base import RunStorage
from ..runs.schema import RunTagsTable, RunsTable
from ..runs.schema import BulkActionsTable, RunTagsTable, RunsTable
from ..tags import PARTITION_NAME_TAG, PARTITION_SET_TAG, REPOSITORY_LABEL_TAG

RUN_PARTITIONS = "run_partitions"
RUN_START_END = "run_start_end_overwritten" # was run_start_end, but renamed to overwrite bad timestamps written
RUN_REPO_LABEL_TAGS = "run_repo_label_tags"
BULK_ACTION_TYPES = "bulk_action_types"

# for `dagster instance migrate`, paired with schema changes
REQUIRED_DATA_MIGRATIONS = {
RUN_PARTITIONS: lambda: migrate_run_partition,
RUN_REPO_LABEL_TAGS: lambda: migrate_run_repo_tags,
BULK_ACTION_TYPES: lambda: migrate_bulk_actions,
}
# for `dagster instance reindex`, optionally run for better read performance
OPTIONAL_DATA_MIGRATIONS = {
RUN_START_END: lambda: migrate_run_start_end,
}

RUN_CHUNK_SIZE = 100
CHUNK_SIZE = 100

UNSTARTED_RUN_STATUSES = {
PipelineRunStatus.QUEUED,
@@ -35,7 +39,7 @@
}


def chunked_run_iterator(storage, print_fn=None, chunk_size=RUN_CHUNK_SIZE):
def chunked_run_iterator(storage, print_fn=None, chunk_size=CHUNK_SIZE):
with ExitStack() as stack:
if print_fn:
run_count = storage.get_runs_count()
@@ -58,7 +62,7 @@ def chunked_run_iterator(storage, print_fn=None, chunk_size=RUN_CHUNK_SIZE):
progress.update(len(chunk)) # pylint: disable=no-member


def chunked_run_records_iterator(storage, print_fn=None, chunk_size=RUN_CHUNK_SIZE):
def chunked_run_records_iterator(storage, print_fn=None, chunk_size=CHUNK_SIZE):
with ExitStack() as stack:
if print_fn:
run_count = storage.get_runs_count()
@@ -165,7 +169,7 @@ def migrate_run_repo_tags(run_storage: RunStorage, print_fn=None):
)
.where(subquery.c.tags_run_id == None)
.order_by(db.asc(RunsTable.c.id))
.limit(RUN_CHUNK_SIZE)
.limit(CHUNK_SIZE)
)

cursor = None
@@ -181,7 +185,7 @@ def migrate_run_repo_tags(run_storage: RunStorage, print_fn=None):
rows = result_proxy.fetchall()
result_proxy.close()

has_more = len(rows) >= RUN_CHUNK_SIZE
has_more = len(rows) >= CHUNK_SIZE
for row in rows:
run = deserialize_as(row[0], PipelineRun)
cursor = row[1]
@@ -205,3 +209,47 @@ def write_repo_tag(conn, run: PipelineRun):
except db.exc.IntegrityError:
# tag already exists, swallow
pass


def migrate_bulk_actions(run_storage: RunStorage, print_fn=None):
from dagster.core.storage.runs.sql_run_storage import SqlRunStorage

if not isinstance(run_storage, SqlRunStorage):
return

if print_fn:
print_fn("Querying run storage.")

base_query = (
db.select([BulkActionsTable.c.body, BulkActionsTable.c.id])
.where(BulkActionsTable.c.action_type == None)
.order_by(db.asc(BulkActionsTable.c.id))
.limit(CHUNK_SIZE)
)

cursor = None
has_more = True
while has_more:
if cursor:
query = base_query.where(BulkActionsTable.c.id > cursor)
else:
query = base_query

with run_storage.connect() as conn:
result_proxy = conn.execute(query)
rows = result_proxy.fetchall()
result_proxy.close()

has_more = len(rows) >= CHUNK_SIZE
for row in rows:
backfill = deserialize_as(row[0], PartitionBackfill)
storage_id = row[1]
conn.execute(
BulkActionsTable.update()
.values(
selector_id=backfill.selector_id,
action_type=BulkActionType.PARTITION_BACKFILL.value,
)
.where(BulkActionsTable.c.id == storage_id)
)
cursor = storage_id
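A rough sketch (not the actual CLI code) of how the REQUIRED_DATA_MIGRATIONS registry above is consumed during dagster instance migrate: each entry is a zero-argument lambda returning a migration function that accepts (run_storage, print_fn=None).

from dagster.core.storage.runs.migration import REQUIRED_DATA_MIGRATIONS


def run_required_data_migrations(run_storage, print_fn=print):
    for name, get_migration_fn in REQUIRED_DATA_MIGRATIONS.items():
        print_fn(f"Starting data migration: {name}")
        get_migration_fn()(run_storage, print_fn=print_fn)
        print_fn(f"Finished data migration: {name}")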
26 changes: 20 additions & 6 deletions python_modules/dagster/dagster/core/storage/runs/sql_run_storage.py
@@ -19,6 +19,7 @@
)
from dagster.core.events import EVENT_TYPE_TO_PIPELINE_RUN_STATUS, DagsterEvent, DagsterEventType
from dagster.core.execution.backfill import BulkActionStatus, PartitionBackfill
from dagster.core.execution.bulk_actions import BulkActionType
from dagster.core.snap import (
ExecutionPlanSnapshot,
PipelineSnapshot,
@@ -941,6 +942,13 @@ def has_run_stats_index_cols(self):
column_names = [x.get("name") for x in db.inspect(conn).get_columns(RunsTable.name)]
return "start_time" in column_names and "end_time" in column_names

def has_bulk_actions_selector_cols(self):
with self.connect() as conn:
column_names = [
x.get("name") for x in db.inspect(conn).get_columns(BulkActionsTable.name)
]
return "selector_id" in column_names

# Daemon heartbeats

def add_daemon_heartbeat(self, daemon_heartbeat: DaemonHeartbeat):
@@ -1019,14 +1027,20 @@ def get_backfill(self, backfill_id: str) -> Optional[PartitionBackfill]:

def add_backfill(self, partition_backfill: PartitionBackfill):
check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)
values = dict(
key=partition_backfill.backfill_id,
status=partition_backfill.status.value,
timestamp=utc_datetime_from_timestamp(partition_backfill.backfill_timestamp),
body=serialize_dagster_namedtuple(partition_backfill),
)

if self.has_bulk_actions_selector_cols():
values["selector_id"] = partition_backfill.selector_id
values["action_type"] = BulkActionType.PARTITION_BACKFILL.value

with self.connect() as conn:
conn.execute(
BulkActionsTable.insert().values( # pylint: disable=no-value-for-parameter
key=partition_backfill.backfill_id,
status=partition_backfill.status.value,
timestamp=utc_datetime_from_timestamp(partition_backfill.backfill_timestamp),
body=serialize_dagster_namedtuple(partition_backfill),
)
BulkActionsTable.insert().values(**values) # pylint: disable=no-value-for-parameter
)

def update_backfill(self, partition_backfill: PartitionBackfill):
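The guarded write in add_backfill (only setting selector_id and action_type when the columns exist) keeps the write path working on instances that have not yet run dagster instance migrate. A hypothetical, more general helper in the same spirit, not part of this commit:

import sqlalchemy as db


def filter_to_existing_columns(conn, table, candidate_values):
    # Keep only the keys that correspond to columns present in the live schema.
    existing = {col["name"] for col in db.inspect(conn).get_columns(table.name)}
    return {key: value for key, value in candidate_values.items() if key in existing}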
49 changes: 49 additions & 0 deletions
@@ -8,6 +8,7 @@
from gzip import GzipFile
from typing import NamedTuple, Optional, Union

import pendulum
import pytest
import sqlalchemy as db

@@ -18,6 +19,7 @@
from dagster.core.definitions.dependency import NodeHandle
from dagster.core.events import DagsterEvent
from dagster.core.events.log import EventLogEntry
from dagster.core.execution.backfill import BulkActionStatus, PartitionBackfill
from dagster.core.instance import DagsterInstance, InstanceRef
from dagster.core.scheduler.instigation import InstigatorState, InstigatorTick
from dagster.core.storage.event_log.migration import migrate_event_log_data
@@ -905,6 +907,13 @@ def test_repo_label_tag_migration():


def test_add_bulk_actions_columns():
from dagster.core.host_representation.origin import (
ExternalPartitionSetOrigin,
ExternalRepositoryOrigin,
GrpcServerRepositoryLocationOrigin,
)
from dagster.core.storage.runs.schema import BulkActionsTable

src_dir = file_relative_path(__file__, "snapshot_0_14_16_bulk_actions_columns/sqlite")

with copy_directory(src_dir) as test_dir:
@@ -931,6 +940,46 @@ def test_add_bulk_actions_columns():
assert "idx_bulk_actions_action_type" in get_sqlite3_indexes(db_path, "bulk_actions")
assert "idx_bulk_actions_selector_id" in get_sqlite3_indexes(db_path, "bulk_actions")

# check data migration
backfill_count = len(instance.get_backfills())
migrated_row_count = instance._run_storage.fetchone(
db.select([db.func.count()])
.select_from(BulkActionsTable)
.where(BulkActionsTable.c.selector_id.isnot(None))
)[0]
assert migrated_row_count > 0
assert backfill_count == migrated_row_count

# check that we are writing to selector id, action types
external_origin = ExternalPartitionSetOrigin(
external_repository_origin=ExternalRepositoryOrigin(
repository_location_origin=GrpcServerRepositoryLocationOrigin(
port=1234, host="localhost"
),
repository_name="fake_repository",
),
partition_set_name="fake",
)
instance.add_backfill(
PartitionBackfill(
backfill_id="simple",
partition_set_origin=external_origin,
status=BulkActionStatus.REQUESTED,
partition_names=["one", "two", "three"],
from_failure=False,
reexecution_steps=None,
tags=None,
backfill_timestamp=pendulum.now().timestamp(),
)
)
unmigrated_row_count = instance._run_storage.fetchone(
db.select([db.func.count()])
.select_from(BulkActionsTable)
.where(BulkActionsTable.c.selector_id == None)
)[0]
assert unmigrated_row_count == 0

# test downgrade
instance._run_storage._alembic_downgrade(rev="721d858e1dda")

assert get_current_alembic_version(db_path) == "721d858e1dda"
