Skip to content

Commit

Permalink
enable cross-job asset partitions (#6865)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed Mar 2, 2022
1 parent 9b08239 commit cbc48e8
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from dagster import check
from dagster.core.definitions.events import AssetKey
from dagster.core.definitions.partition import PartitionsDefinition
from dagster.core.definitions.partition_key_range import PartitionKeyRange

from .asset import AssetsDefinition


def get_upstream_partitions_for_partition_range(
downstream_assets_def: AssetsDefinition,
upstream_assets_def: AssetsDefinition,
upstream_partitions_def: PartitionsDefinition,
upstream_asset_key: AssetKey,
downstream_partition_key_range: PartitionKeyRange,
) -> PartitionKeyRange:
Expand All @@ -18,14 +19,14 @@ def get_upstream_partitions_for_partition_range(
if downstream_assets_def.partitions_def is None:
check.failed("downstream asset is not partitioned")

if upstream_assets_def.partitions_def is None:
if upstream_partitions_def is None:
check.failed("upstream asset is not partitioned")

downstream_partition_mapping = downstream_assets_def.get_partition_mapping(upstream_asset_key)
return downstream_partition_mapping.get_upstream_partitions_for_partition_range(
downstream_partition_key_range,
downstream_assets_def.partitions_def,
upstream_assets_def.partitions_def,
upstream_partitions_def,
)


Expand Down
27 changes: 16 additions & 11 deletions python_modules/dagster/dagster/core/asset_defs/assets_job.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from typing import AbstractSet, Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast

from dagster import check
Expand Down Expand Up @@ -82,7 +83,7 @@ def asset2(asset1):

op_defs = build_op_deps(assets, source_assets_by_key.keys())
root_manager = build_root_manager(source_assets_by_key)
partitioned_config = build_job_partitions_from_assets(assets)
partitioned_config = build_job_partitions_from_assets(assets, source_assets or [])

return GraphDefinition(
name=name,
Expand All @@ -104,13 +105,14 @@ def asset2(asset1):

def build_job_partitions_from_assets(
assets: Sequence[AssetsDefinition],
source_assets: Sequence[Union[SourceAsset, AssetsDefinition]],
) -> Optional[PartitionedConfig]:
assets_with_partitions_defs = [assets_def for assets_def in assets if assets_def.partitions_def]

if len(assets_with_partitions_defs) == 0:
return None

first_assets_with_partitions_def = assets_with_partitions_defs[0]
first_assets_with_partitions_def: AssetsDefinition = assets_with_partitions_defs[0]
for assets_def in assets_with_partitions_defs:
if assets_def.partitions_def != first_assets_with_partitions_def.partitions_def:
first_asset_key = next(iter(assets_def.asset_keys)).to_string()
Expand All @@ -121,9 +123,14 @@ def build_job_partitions_from_assets(
f"'{second_asset_key}' have different partitions definitions. "
)

assets_defs_by_asset_key = {
asset_key: assets_def for assets_def in assets for asset_key in assets_def.asset_keys
}
partitions_defs_by_asset_key: Dict[AssetKey, PartitionsDefinition] = {}
asset: Union[AssetsDefinition, SourceAsset]
for asset in itertools.chain.from_iterable([assets, source_assets]):
if isinstance(asset, AssetsDefinition) and asset.partitions_def is not None:
for asset_key in asset.asset_keys:
partitions_defs_by_asset_key[asset_key] = asset.partitions_def
elif isinstance(asset, SourceAsset) and asset.partitions_def is not None:
partitions_defs_by_asset_key[asset.key] = asset.partitions_def

def asset_partitions_for_job_partition(
job_partition_key: str,
Expand Down Expand Up @@ -151,13 +158,10 @@ def run_config_for_partition_fn(partition_key: str) -> Dict[str, Any]:

inputs_dict: Dict[str, Dict[str, Any]] = {}
for in_asset_key, input_def in assets_def.input_defs_by_asset_key.items():
upstream_assets_def = assets_defs_by_asset_key[in_asset_key]
if (
assets_def.partitions_def is not None
and upstream_assets_def.partitions_def is not None
):
upstream_partitions_def = partitions_defs_by_asset_key[in_asset_key]
if assets_def.partitions_def is not None and upstream_partitions_def is not None:
upstream_partition_key_range = get_upstream_partitions_for_partition_range(
assets_def, upstream_assets_def, in_asset_key, asset_partition_key_range
assets_def, upstream_partitions_def, in_asset_key, asset_partition_key_range
)
inputs_dict[input_def.name] = {
"start": upstream_partition_key_range.start,
Expand Down Expand Up @@ -256,6 +260,7 @@ def _op():
dagster_type=input_context.dagster_type,
upstream_output=output_context,
op_def=input_context.op_def,
step_context=input_context.step_context,
)

io_manager = getattr(cast(Any, input_context.resources), source_asset.io_manager_key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def asset(
the asset, e.g. "dbt" or "spark". It will be displayed in Dagit as a badge on the asset.
dagster_type (Optional[DagsterType]): Allows specifying type validation functions that
will be executed on the output of the decorated function after it runs.
partitions_def (Optional[PartitionsDefiniition]): Defines the set of partition keys that
partitions_def (Optional[PartitionsDefinition]): Defines the set of partition keys that
compose the asset.
partition_mappings (Optional[Mapping[str, PartitionMapping]]): Defines how to map partition
keys for this asset to partition keys of upstream assets. Each key in the dictionary
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, NamedTuple, Optional

from dagster.core.definitions.events import AssetKey
from dagster.core.definitions.partition import PartitionsDefinition


class SourceAsset(NamedTuple):
Expand All @@ -13,9 +14,12 @@ class SourceAsset(NamedTuple):
io_manager_key (str): The key for the IOManager that will be used to load the contents of
the asset when it's used as an input to other assets inside a job.
description (Optional[str]): The description of the asset.
partitions_def (Optional[PartitionsDefinition]): Defines the set of partition keys that
compose the asset.
"""

key: AssetKey
metadata: Optional[Any] = None
io_manager_key: str = "io_manager"
description: Optional[str] = None
partitions_def: Optional[PartitionsDefinition] = None
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def build_input_context(
resource_config: Optional[Dict[str, Any]] = None,
resources: Optional[Dict[str, Any]] = None,
op_def: Optional[OpDefinition] = None,
step_context: Optional["StepExecutionContext"] = None,
) -> "InputContext":
"""Builds input context from provided parameters.
Expand All @@ -320,6 +321,7 @@ def build_input_context(
definition.
asset_key (Optional[AssetKey]): The asset key attached to the InputDefinition.
op_def (Optional[OpDefinition]): The definition of the op that's loading the input.
step_context (Optional[StepExecutionContext]): For internal use.
Examples:
Expand All @@ -331,6 +333,7 @@ def build_input_context(
do_something
"""
from dagster.core.execution.context.output import OutputContext
from dagster.core.execution.context.system import StepExecutionContext
from dagster.core.execution.context_creation_pipeline import initialize_console_manager
from dagster.core.types.dagster_type import DagsterType

Expand All @@ -341,6 +344,7 @@ def build_input_context(
resource_config = check.opt_dict_param(resource_config, "resource_config", key_type=str)
resources = check.opt_dict_param(resources, "resources", key_type=str)
op_def = check.opt_inst_param(op_def, "op_def", OpDefinition)
step_context = check.opt_inst_param(step_context, "step_context", StepExecutionContext)

return InputContext(
name=name,
Expand All @@ -352,6 +356,6 @@ def build_input_context(
log_manager=initialize_console_manager(None),
resource_config=resource_config,
resources=resources,
step_context=None,
step_context=step_context,
op_def=op_def,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
AssetMaterialization,
DagsterInvalidDefinitionError,
DailyPartitionsDefinition,
HourlyPartitionsDefinition,
IOManager,
IOManagerDefinition,
PartitionsDefinition,
SourceAsset,
StaticPartitionsDefinition,
)
from dagster.core.asset_defs import asset, build_assets_job
Expand Down Expand Up @@ -378,3 +380,58 @@ def downstream_asset(upstream_asset):
resource_defs={"io_manager": IOManagerDefinition.hardcoded_io_manager(MyIOManager())},
)
my_job.execute_in_process(partition_key="2020-01-02")


def test_cross_job_different_partitions():
    """Run a daily-partitioned asset whose hourly-partitioned upstream is a source asset.

    The upstream here is an ``@asset`` definition that is passed to
    ``build_assets_job`` through ``source_assets`` instead of ``assets``. The IO
    manager asserts the partition key range it is asked to load when the job is
    executed for partition "2021-06-06".
    """
    hourly_partitions = HourlyPartitionsDefinition(start_date="2021-05-05-00:00")
    daily_partitions = DailyPartitionsDefinition(start_date="2021-05-05")

    @asset(partitions_def=hourly_partitions)
    def hourly_asset():
        pass

    @asset(partitions_def=daily_partitions)
    def daily_asset(hourly_asset):
        # The IO manager's load_input below returns nothing.
        assert hourly_asset is None

    class RangeCheckingIOManager(IOManager):
        def handle_output(self, context, obj):
            pass

        def load_input(self, context):
            loaded_range = context.asset_partition_key_range
            assert loaded_range.start == "2021-06-06-00:00"
            assert loaded_range.end == "2021-06-06-23:00"

    io_manager_def = IOManagerDefinition.hardcoded_io_manager(RangeCheckingIOManager())
    job_def = build_assets_job(
        name="daily_job",
        assets=[daily_asset],
        source_assets=[hourly_asset],
        resource_defs={"io_manager": io_manager_def},
    )

    result = job_def.execute_in_process(partition_key="2021-06-06")
    assert result.success


def test_source_asset_partitions():
    """Run a daily-partitioned asset whose upstream is an hourly-partitioned ``SourceAsset``.

    The IO manager asserts the partition key range it is asked to load when the
    job is executed for partition "2021-06-06".
    """
    hourly_partitions = HourlyPartitionsDefinition(start_date="2021-05-05-00:00")
    daily_partitions = DailyPartitionsDefinition(start_date="2021-05-05")

    hourly_asset = SourceAsset(AssetKey("hourly_asset"), partitions_def=hourly_partitions)

    @asset(partitions_def=daily_partitions)
    def daily_asset(hourly_asset):
        # The IO manager's load_input below returns nothing.
        assert hourly_asset is None

    class RangeCheckingIOManager(IOManager):
        def handle_output(self, context, obj):
            pass

        def load_input(self, context):
            loaded_range = context.asset_partition_key_range
            assert loaded_range.start == "2021-06-06-00:00"
            assert loaded_range.end == "2021-06-06-23:00"

    io_manager_def = IOManagerDefinition.hardcoded_io_manager(RangeCheckingIOManager())
    job_def = build_assets_job(
        name="daily_job",
        assets=[daily_asset],
        source_assets=[hourly_asset],
        resource_defs={"io_manager": io_manager_def},
    )

    result = job_def.execute_in_process(partition_key="2021-06-06")
    assert result.success
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

from dagster import (
AssetKey,
MetadataValue,
Out,
Output,
SolidExecutionContext,
TableColumn,
TableSchema,
check,
get_dagster_logger,
MetadataValue,
)
from dagster.core.asset_defs import AssetsDefinition, multi_asset

Expand Down

0 comments on commit cbc48e8

Please sign in to comment.