Skip to content

Commit

Permalink
Allow setting logical version inside op (#12189)
Browse files Browse the repository at this point in the history
### Summary & Motivation

Allow setting the `logical_version` via `Output` object. This will
override the auto-generated logical version.

### How I Tested These Changes

New unit test
  • Loading branch information
smackesey committed Feb 23, 2023
1 parent e3c8825 commit 2d43d39
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 94 deletions.
15 changes: 15 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/events.py
Expand Up @@ -20,8 +20,10 @@
import dagster._check as check
import dagster._seven as seven
from dagster._annotations import PublicAttr, public
from dagster._core.definitions.logical_version import LogicalVersion
from dagster._core.storage.tags import MULTIDIMENSIONAL_PARTITION_PREFIX, SYSTEM_TAG_PREFIX
from dagster._serdes import DefaultNamedTupleSerializer, whitelist_for_serdes
from dagster._utils.backcompat import experimental_class_param_warning

from .metadata import (
MetadataEntry,
Expand Down Expand Up @@ -234,6 +236,8 @@ class Output(Generic[T]):
Arbitrary metadata about the failure. Keys are displayed string labels, and values are
one of the following: string, float, int, JSON-serializable dict, JSON-serializable
list, and one of the data classes returned by a MetadataValue static method.
logical_version (Optional[LogicalVersion]): (Experimental) A logical version to manually set
for the asset.
"""

def __init__(
Expand All @@ -242,6 +246,7 @@ def __init__(
output_name: Optional[str] = DEFAULT_OUTPUT,
metadata_entries: Optional[Sequence[Union[MetadataEntry, PartitionMetadataEntry]]] = None,
metadata: Optional[Mapping[str, RawMetadataValue]] = None,
logical_version: Optional[LogicalVersion] = None,
):
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
metadata_entries = check.opt_sequence_param(
Expand All @@ -252,6 +257,11 @@ def __init__(
self._value = value
self._output_name = check.str_param(output_name, "output_name")
self._metadata_entries = normalize_metadata(metadata, metadata_entries)
if logical_version is not None:
experimental_class_param_warning("logical_version", "Output")
self._logical_version = check.opt_inst_param(
logical_version, "logical_version", LogicalVersion
)

@property
def metadata_entries(self) -> Sequence[Union[PartitionMetadataEntry, MetadataEntry]]:
Expand All @@ -267,6 +277,11 @@ def value(self) -> Any:
def output_name(self) -> str:
return self._output_name

@public
@property
def logical_version(self) -> Optional[LogicalVersion]:
return self._logical_version

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Output)
Expand Down
27 changes: 18 additions & 9 deletions python_modules/dagster/dagster/_core/execution/context/system.py
Expand Up @@ -24,10 +24,6 @@
from dagster._annotations import public
from dagster._core.definitions.events import AssetKey, AssetLineageInfo
from dagster._core.definitions.hook_definition import HookDefinition
from dagster._core.definitions.logical_version import (
LogicalVersion,
extract_logical_version_from_entry,
)
from dagster._core.definitions.mode import ModeDefinition
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.definitions.partition import PartitionsDefinition, PartitionsSubset
Expand Down Expand Up @@ -67,6 +63,9 @@
if TYPE_CHECKING:
from dagster._core.definitions.dependency import Node, NodeHandle
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.logical_version import (
LogicalVersion,
)
from dagster._core.definitions.resource_definition import Resources
from dagster._core.event_api import EventLogRecord
from dagster._core.execution.plan.plan import ExecutionPlan
Expand Down Expand Up @@ -488,7 +487,7 @@ def __init__(

self._input_asset_records: Dict[AssetKey, Optional["EventLogRecord"]] = {}
self._is_external_input_asset_records_loaded = False
self._generated_logical_versions: Dict[AssetKey, LogicalVersion] = {}
self._logical_version_cache: Dict[AssetKey, "LogicalVersion"] = {}

@property
def step(self) -> ExecutionStep:
Expand Down Expand Up @@ -822,8 +821,14 @@ def step_materializes_assets(self) -> bool:
)
return asset_info is not None

def record_logical_version(self, asset_key: AssetKey, logical_version: LogicalVersion) -> None:
self._generated_logical_versions[asset_key] = logical_version
def set_logical_version(self, asset_key: AssetKey, logical_version: "LogicalVersion") -> None:
self._logical_version_cache[asset_key] = logical_version

def has_logical_version(self, asset_key: AssetKey) -> bool:
return asset_key in self._logical_version_cache

def get_logical_version(self, asset_key: AssetKey) -> "LogicalVersion":
return self._logical_version_cache[asset_key]

@property
def input_asset_records(self) -> Optional[Mapping[AssetKey, Optional["EventLogRecord"]]]:
Expand Down Expand Up @@ -865,12 +870,16 @@ def fetch_external_input_asset_records(self) -> None:
self._is_external_input_asset_records_loaded = True

def _fetch_input_asset_record(self, key: AssetKey, retries: int = 0) -> None:
from dagster._core.definitions.logical_version import (
extract_logical_version_from_entry,
)

event = self.instance.get_latest_logical_version_record(key)
if key in self._generated_logical_versions and retries <= 5:
if key in self._logical_version_cache and retries <= 5:
event_logical_version = (
None if event is None else extract_logical_version_from_entry(event.event_log_entry)
)
if event_logical_version == self._generated_logical_versions[key]:
if event_logical_version == self._logical_version_cache[key]:
self._input_asset_records[key] = event
else:
self._fetch_input_asset_record(key, retries + 1)
Expand Down
Expand Up @@ -251,16 +251,16 @@ def validate_and_coerce_op_result_to_iterator(
f" output '{output_def.name}' which does not have an Output annotation."
f" Annotation has type {annotation}."
)
output = cast(Output, element)
_check_output_object_name(output, output_def, position)
_check_output_object_name(element, output_def, position)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)

yield Output(
output_name=output_def.name,
value=output.value,
metadata_entries=output.metadata_entries,
value=element.value,
metadata_entries=element.metadata_entries,
logical_version=element.logical_version,
)
else:
# If annotation indicates a generic output annotation, and an
Expand Down
66 changes: 52 additions & 14 deletions python_modules/dagster/dagster/_core/execution/plan/execute_step.py
Expand Up @@ -6,13 +6,16 @@
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)

from typing_extensions import TypedDict

import dagster._check as check
from dagster._core.definitions import (
AssetKey,
Expand Down Expand Up @@ -133,6 +136,7 @@ def _step_output_error_checked_user_event_sequence(
*output.metadata_entries,
*normalize_metadata(cast(Dict[str, Any], metadata), []),
],
logical_version=output.logical_version,
)
else:
if not output_def.is_dynamic:
Expand Down Expand Up @@ -474,9 +478,21 @@ def _get_output_asset_materializations(
step_context.is_external_input_asset_records_loaded
and asset_key in step_context.pipeline_def.asset_layer.asset_keys
):
tags = _build_logical_version_tags(asset_key, step_context)
logical_version = LogicalVersion(tags[LOGICAL_VERSION_TAG_KEY])
step_context.record_logical_version(asset_key, logical_version)
assert isinstance(output, Output)
code_version = _get_code_version(asset_key, step_context)
input_provenance_data = _get_input_provenance_data(asset_key, step_context)
logical_version = (
compute_logical_version(
code_version,
{k: meta["logical_version"] for k, meta in input_provenance_data.items()},
)
if output.logical_version is None
else output.logical_version
)
tags = _build_logical_version_tags(logical_version, code_version, input_provenance_data)
if not step_context.has_logical_version(asset_key):
logical_version = LogicalVersion(tags[LOGICAL_VERSION_TAG_KEY])
step_context.set_logical_version(asset_key, logical_version)
else:
tags = {}

Expand Down Expand Up @@ -538,14 +554,22 @@ def _get_output_asset_materializations(
)


def _build_logical_version_tags(
def _get_code_version(asset_key: AssetKey, step_context: StepExecutionContext) -> str:
return (
step_context.pipeline_def.asset_layer.code_version_for_asset(asset_key)
or step_context.dagster_run.run_id
)


class _InputProvenanceData(TypedDict):
logical_version: LogicalVersion
storage_id: Optional[int]


def _get_input_provenance_data(
asset_key: AssetKey, step_context: StepExecutionContext
) -> Dict[str, str]:
asset_layer = step_context.pipeline_def.asset_layer
code_version = asset_layer.code_version_for_asset(asset_key) or step_context.dagster_run.run_id
input_logical_versions: Dict[AssetKey, LogicalVersion] = {}
tags: Dict[str, str] = {}
tags[CODE_VERSION_TAG_KEY] = code_version
) -> Mapping[AssetKey, _InputProvenanceData]:
input_provenance: Dict[AssetKey, _InputProvenanceData] = {}
deps = step_context.pipeline_def.asset_layer.upstream_assets_for_asset(asset_key)
for key in deps:
# For deps external to this step, this will retrieve the cached record that was stored prior
Expand All @@ -560,11 +584,25 @@ def _build_logical_version_tags(
)
else:
logical_version = DEFAULT_LOGICAL_VERSION
input_logical_versions[key] = logical_version
tags[get_input_logical_version_tag_key(key)] = logical_version.value
tags[get_input_event_pointer_tag_key(key)] = str(event.storage_id) if event else "NULL"
input_provenance[key] = {
"logical_version": logical_version,
"storage_id": event.storage_id if event else None,
}
return input_provenance

logical_version = compute_logical_version(code_version, input_logical_versions)

def _build_logical_version_tags(
logical_version: LogicalVersion,
code_version: str,
input_provenance_data: Mapping[AssetKey, _InputProvenanceData],
) -> Dict[str, str]:
tags: Dict[str, str] = {}
tags[CODE_VERSION_TAG_KEY] = code_version
for key, meta in input_provenance_data.items():
tags[get_input_logical_version_tag_key(key)] = meta["logical_version"].value
tags[get_input_event_pointer_tag_key(key)] = (
str(meta["storage_id"]) if meta["storage_id"] else "NULL"
)
tags[LOGICAL_VERSION_TAG_KEY] = logical_version.value
return tags

Expand Down

0 comments on commit 2d43d39

Please sign in to comment.