Skip to content

Commit

Permalink
[2/3 partition status cache] Serialize version, rebuild cache for uns…
Browse files Browse the repository at this point in the history
…upported version (#11687)

This PR updates partitions subsets to also include a serialization
version number in the serialized data. This enables us to gracefully
handle older serialized data should we update the serialization content.

It also updates the partition status cache to rebuild the cached data when the
serialized version is unsupported.
  • Loading branch information
clairelin135 committed Jan 18, 2023
1 parent 6d1c659 commit 8c78399
Show file tree
Hide file tree
Showing 12 changed files with 433 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import itertools
import json
from datetime import datetime
from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple, cast

import dagster._check as check
from dagster._annotations import experimental
from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvalidInvocationError
from dagster._core.errors import (
DagsterInvalidDefinitionError,
DagsterInvalidInvocationError,
)
from dagster._core.storage.tags import (
MULTIDIMENSIONAL_PARTITION_PREFIX,
get_multidimensional_partition_tag,
Expand All @@ -15,8 +17,10 @@
DefaultPartitionsSubset,
Partition,
PartitionsDefinition,
PartitionsSubset,
StaticPartitionsDefinition,
)
from .time_window_partitions import TimeWindowPartitionsDefinition

INVALID_STATIC_PARTITIONS_KEY_CHARACTERS = set(["|", ",", "[", "]"])

Expand Down Expand Up @@ -232,7 +236,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"{type(self).__name__}(dimensions={[str(dim) for dim in self.partitions_defs]}"

def get_multi_partition_key_from_str(self, partition_key_str: str) -> MultiPartitionKey:
def get_partition_key_from_str(self, partition_key_str: str) -> str:
"""
Given a string representation of a partition key, returns a MultiPartitionKey object.
"""
Expand All @@ -246,43 +250,68 @@ def get_multi_partition_key_from_str(self, partition_key_str: str) -> MultiParti
f" {partition_key_str}, but got {len(partition_key_strs)}"
),
)
keys_per_dimension = [
(dim.name, dim.partitions_def.get_partition_keys()) for dim in self._partitions_defs
]

partition_key_dims_by_idx = dict(enumerate([dim.name for dim in self._partitions_defs]))
for idx, key in enumerate(partition_key_strs):
check.invariant(
key in keys_per_dimension[idx][1],
f"Partition key {key} not found in dimension {partition_key_dims_by_idx[idx][0]}",
)

multi_partition_key = MultiPartitionKey(
{partition_key_dims_by_idx[idx]: key for idx, key in enumerate(partition_key_strs)}
return MultiPartitionKey(
{dim.name: partition_key_strs[i] for i, dim in enumerate(self._partitions_defs)}
)
return multi_partition_key

def deserialize_subset(self, serialized: str) -> "MultiPartitionsSubset":
def empty_subset(self) -> "MultiPartitionsSubset":
    """Return a MultiPartitionsSubset over this definition containing no partition keys."""
    no_keys: Set[str] = set()
    return MultiPartitionsSubset(self, no_keys)

def deserialize_subset(self, serialized: str) -> "PartitionsSubset":
    """Rebuild a MultiPartitionsSubset from a string previously produced by `serialize`."""
    return MultiPartitionsSubset.from_serialized(partitions_def=self, serialized=serialized)

def _get_primary_and_secondary_dimension(
    self,
) -> Tuple[PartitionDimensionDefinition, PartitionDimensionDefinition]:
    """Pick the (primary, secondary) dimension pair used for serialization ordering.

    Multipartitions subsets are serialized by primary dimension. If changing
    the selection of primary/secondary dimension, will need to also update the
    serialization of MultiPartitionsSubsets.
    """
    time_dims = [
        dim
        for dim in self.partitions_defs
        if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
    ]
    if len(time_dims) != 1:
        # Zero or multiple time dimensions: fall back to definition order.
        return self.partitions_defs[0], self.partitions_defs[1]

    # Exactly one time dimension: it becomes primary; the first other dimension is secondary.
    primary = time_dims[0]
    secondary = next(dim for dim in self.partitions_defs if dim != primary)
    return primary, secondary

@property
def primary_dimension(self) -> PartitionDimensionDefinition:
    """The dimension that partition subsets are serialized by."""
    primary, _ = self._get_primary_and_secondary_dimension()
    return primary

@property
def secondary_dimension(self) -> PartitionDimensionDefinition:
    """The dimension paired with `primary_dimension` during serialization."""
    _, secondary = self._get_primary_and_secondary_dimension()
    return secondary


class MultiPartitionsSubset(DefaultPartitionsSubset):
@staticmethod
def from_serialized(
partitions_def: PartitionsDefinition, serialized: str
) -> "MultiPartitionsSubset":
if not isinstance(partitions_def, MultiPartitionsDefinition):
check.failed(
"Must pass a MultiPartitionsDefinition object to deserialize MultiPartitionsSubset."
)
def __init__(
    self,
    partitions_def: MultiPartitionsDefinition,
    subset: Optional[Set[str]] = None,
):
    """Build a subset of a MultiPartitionsDefinition.

    Each incoming key string is normalized into a MultiPartitionKey before being
    stored in the underlying DefaultPartitionsSubset.
    """
    check.inst_param(partitions_def, "partitions_def", MultiPartitionsDefinition)
    if subset:
        normalized = {partitions_def.get_partition_key_from_str(key) for key in subset}
    else:
        normalized = set()
    super().__init__(partitions_def, normalized)

def with_partition_keys(self, partition_keys: Iterable[str]) -> "MultiPartitionsSubset":
return MultiPartitionsSubset(
subset=set(
[
partitions_def.get_multi_partition_key_from_str(key)
for key in json.loads(serialized)
]
),
partitions_def=partitions_def,
cast(MultiPartitionsDefinition, self._partitions_def),
self._subset | set(partition_keys),
)


Expand Down
56 changes: 43 additions & 13 deletions python_modules/dagster/dagster/_core/definitions/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..decorator_utils import get_function_params
from ..errors import (
DagsterInvalidDefinitionError,
DagsterInvalidDeserializationVersionError,
DagsterInvalidInvocationError,
DagsterInvariantViolationError,
DagsterUnknownPartitionError,
Expand Down Expand Up @@ -1069,16 +1070,24 @@ def get_partition_key_ranges(
def with_partition_keys(self, partition_keys: Iterable[str]) -> "PartitionsSubset":
raise NotImplementedError()

@abstractmethod
def with_partition_key_range(
self, partition_key_range: PartitionKeyRange
) -> "PartitionsSubset":
raise NotImplementedError()
return self.with_partition_keys(
self.partitions_def.get_partition_keys_in_range(partition_key_range)
)

@abstractmethod
def serialize(self) -> str:
    """Serialize this subset to a string for storage; subclasses must implement."""
    raise NotImplementedError()

@classmethod
@abstractmethod
def from_serialized(
    cls, partitions_def: PartitionsDefinition, serialized: str
) -> "PartitionsSubset":
    """Reconstruct a subset from a string produced by `serialize`; subclasses must implement."""
    raise NotImplementedError()

@property
@abstractmethod
def partitions_def(self) -> PartitionsDefinition:
Expand All @@ -1094,6 +1103,10 @@ def __contains__(self, value) -> bool:


class DefaultPartitionsSubset(PartitionsSubset):
# Every time we change the serialization format, we should increment the version number.
# This will ensure that we can gracefully degrade when deserializing old data.
SERIALIZATION_VERSION = 1

def __init__(self, partitions_def: PartitionsDefinition, subset: Optional[Set[str]] = None):
check.opt_set_param(subset, "subset")
self._partitions_def = partitions_def
Expand All @@ -1102,7 +1115,9 @@ def __init__(self, partitions_def: PartitionsDefinition, subset: Optional[Set[st
def get_partition_keys_not_in_subset(
self, current_time: Optional[datetime] = None
) -> Iterable[str]:
return set(self._partitions_def.get_partition_keys()) - self._subset
return (
set(self._partitions_def.get_partition_keys(current_time=current_time)) - self._subset
)

def get_partition_keys(self, current_time: Optional[datetime] = None) -> Iterable[str]:
    # `current_time` is intentionally unused here: the stored subset is already an
    # explicit, materialized set of keys and needs no time-based evaluation.
    return self._subset
Expand Down Expand Up @@ -1130,7 +1145,10 @@ def get_partition_key_ranges(
return result

def with_partition_keys(self, partition_keys: Iterable[str]) -> "DefaultPartitionsSubset":
return DefaultPartitionsSubset(self._partitions_def, self._subset | set(partition_keys))
return DefaultPartitionsSubset(
self._partitions_def,
self._subset | set(partition_keys),
)

def with_partition_key_range(
self, partition_key_range: PartitionKeyRange
Expand All @@ -1140,7 +1158,27 @@ def with_partition_key_range(
)

def serialize(self) -> str:
return json.dumps(list(self._subset))
# Serialize version number, so attempting to deserialize old versions can be handled gracefully.
# Any time the serialization format changes, we should increment the version number.
return json.dumps({"version": self.SERIALIZATION_VERSION, "subset": list(self._subset)})

@classmethod
def from_serialized(
    cls, partitions_def: PartitionsDefinition, serialized: str
) -> "PartitionsSubset":
    """Deserialize a subset, accepting both the legacy list format and the versioned dict format.

    Raises DagsterInvalidDeserializationVersionError when the stored version number
    does not match SERIALIZATION_VERSION.
    """
    data = json.loads(serialized)

    # Legacy payloads were a bare JSON list of keys with no version envelope.
    if isinstance(data, list):
        return cls(subset=set(data), partitions_def=partitions_def)

    version = data.get("version")
    if version != cls.SERIALIZATION_VERSION:
        raise DagsterInvalidDeserializationVersionError(
            f"Attempted to deserialize partition subset with version {version},"
            f" but only version {cls.SERIALIZATION_VERSION} is supported."
        )
    return cls(subset=set(data.get("subset")), partitions_def=partitions_def)

@property
def partitions_def(self) -> PartitionsDefinition:
Expand All @@ -1159,14 +1197,6 @@ def __len__(self) -> int:
def __contains__(self, value) -> bool:
    """Support `key in subset` membership tests against the stored key set."""
    return value in self._subset

@staticmethod
def from_serialized(
partitions_def: PartitionsDefinition, serialized: str
) -> "DefaultPartitionsSubset":
return DefaultPartitionsSubset(
subset=set(json.loads(serialized)), partitions_def=partitions_def
)

def __repr__(self) -> str:
return (
f"DefaultPartitionsSubset(subset={self._subset}, partitions_def={self._partitions_def})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
from dagster._utils.schedules import cron_string_iterator, reverse_cron_string_iterator

from ..errors import DagsterInvalidDeserializationVersionError
from .partition import (
DEFAULT_DATE_FORMAT,
Partition,
Expand Down Expand Up @@ -212,6 +213,17 @@ def time_windows_for_partition_keys(
self._iterate_time_windows(datetime.strptime(partition_key, self.fmt))
)
partition_key_time_windows.append(next(cur_windows_iterator))

end_tw = self.get_last_partition_window()
if end_tw is None:
check.failed("No end time window found")
end_timestamp = end_tw.end.timestamp()
partition_key_time_windows = [
tw
for tw in partition_key_time_windows
if tw.start.timestamp() >= self.start.timestamp()
and tw.end.timestamp() <= end_timestamp
]
return partition_key_time_windows

def start_time_for_partition_key(self, partition_key: str) -> datetime:
Expand Down Expand Up @@ -523,7 +535,7 @@ def less_than(self, partition_key1: str, partition_key2: str) -> bool:
def empty_subset(self) -> "TimeWindowPartitionsSubset":
    """Return a TimeWindowPartitionsSubset with no included time windows."""
    no_windows: List[TimeWindow] = []
    return TimeWindowPartitionsSubset(self, no_windows, 0)

def deserialize_subset(self, serialized: str) -> "TimeWindowPartitionsSubset":
def deserialize_subset(self, serialized: str) -> "PartitionsSubset":
return TimeWindowPartitionsSubset.from_serialized(self, serialized)

@property
Expand Down Expand Up @@ -1046,14 +1058,20 @@ def inner(fn: Callable[[datetime, datetime], Mapping[str, Any]]) -> PartitionedC


class TimeWindowPartitionsSubset(PartitionsSubset):
# Every time we change the serialization format, we should increment the version number.
# This will ensure that we can gracefully degrade when deserializing old data.
SERIALIZATION_VERSION = 1

def __init__(
self,
partitions_def: TimeWindowPartitionsDefinition,
included_time_windows: Sequence[TimeWindow],
num_partitions: int,
):
self._partitions_def = check.inst_param(
partitions_def, "partitions_def", TimeWindowPartitionsDefinition
)
check.sequence_param(included_time_windows, "included_time_windows", of_type=TimeWindow)
self._partitions_def = partitions_def
self._included_time_windows = included_time_windows
self._num_partitions = num_partitions

Expand Down Expand Up @@ -1185,10 +1203,14 @@ def with_partition_key_range(
self._partitions_def.get_partition_keys_in_range(partition_key_range)
)

@staticmethod
@classmethod
def from_serialized(
partitions_def: TimeWindowPartitionsDefinition, serialized: str
) -> "TimeWindowPartitionsSubset":
cls, partitions_def: PartitionsDefinition, serialized: str
) -> "PartitionsSubset":
if not isinstance(partitions_def, TimeWindowPartitionsDefinition):
check.failed("Partitions definition must be a TimeWindowPartitionsDefinition")
partitions_def = cast(TimeWindowPartitionsDefinition, partitions_def)

loaded = json.loads(serialized)

def tuples_to_time_windows(tuples):
Expand All @@ -1207,15 +1229,23 @@ def tuples_to_time_windows(tuples):
len(partitions_def.get_partition_keys_in_time_window(time_window))
for time_window in time_windows
)
else:
elif isinstance(loaded, dict) and (
"version" not in loaded or loaded["version"] == cls.SERIALIZATION_VERSION
): # version 1
time_windows = tuples_to_time_windows(loaded["time_windows"])
num_partitions = loaded["num_partitions"]
else:
raise DagsterInvalidDeserializationVersionError(
f"Attempted to deserialize partition subset with version {loaded.get('version')},"
f" but only version {cls.SERIALIZATION_VERSION} is supported."
)

return TimeWindowPartitionsSubset(partitions_def, time_windows, num_partitions)

def serialize(self) -> str:
return json.dumps(
{
"version": self.SERIALIZATION_VERSION,
"time_windows": [
(window.start.timestamp(), window.end.timestamp())
for window in self._included_time_windows
Expand Down
4 changes: 4 additions & 0 deletions python_modules/dagster/dagster/_core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class DagsterInvalidSubsetError(DagsterError):
"""


class DagsterInvalidDeserializationVersionError(DagsterError):
    """Indicates that a serialized value has an unsupported version and cannot be deserialized.

    Raised by ``from_serialized`` implementations when the stored version number does
    not match the serialization version the current code supports.
    """


CONFIG_ERROR_VERBIAGE = """
This value can be a:
- Field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,9 @@ def __len__(self) -> int:

def __contains__(self, value) -> bool:
raise NotImplementedError()

@classmethod
def from_serialized(
    cls, partitions_def: "PartitionsDefinition", serialized: str
) -> "PartitionsSubset":
    """Deserialization is not supported for this subset type; always raises."""
    raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def _construct_asset_record_from_row(self, row, last_materialization: Optional[E
last_run_id=row[3],
asset_details=AssetDetails.from_db_string(row[4]),
cached_status=AssetStatusCacheValue.from_db_string(row[5])
if self.has_asset_key_col("cached_status_data")
if self.can_cache_asset_status_data()
else None,
),
)
Expand Down Expand Up @@ -1235,7 +1235,7 @@ def _fetch_raw_asset_rows(self, asset_keys=None, prefix=None, limit=None, cursor
def update_asset_cached_status_data(
self, asset_key: AssetKey, cache_values: "AssetStatusCacheValue"
) -> None:
if self.has_asset_key_col("cached_status_data"):
if self.can_cache_asset_status_data():
with self.index_connection() as conn:
conn.execute(
AssetKeyTable.update() # pylint: disable=no-value-for-parameter
Expand Down Expand Up @@ -1545,7 +1545,7 @@ def _get_asset_key_values_on_wipe(self):
wipe_timestamp=utc_datetime_from_timestamp(wipe_timestamp),
)
)
if self.has_asset_key_col("cached_status_data"):
if self.can_cache_asset_status_data():
values.update(dict(cached_status_data=None))
return values

Expand Down

0 comments on commit 8c78399

Please sign in to comment.