Skip to content

Commit

Permalink
[assets] add partition_def and config to define_asset_job (#8282)
Browse files Browse the repository at this point in the history
Co-authored-by: Sandy Ryza <sandy@elementl.com>
  • Loading branch information
OwenKephart and sryza committed Jun 9, 2022
1 parent c3f4d36 commit d684848
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 129 deletions.
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/core/asset_defs/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def to_source_assets(self) -> Sequence[SourceAsset]:
io_manager_key=output_def.io_manager_key,
description=output_def.description,
resource_defs=self.resource_defs,
partitions_def=self.partitions_def,
)
)

Expand Down
87 changes: 10 additions & 77 deletions python_modules/dagster/dagster/core/asset_defs/assets_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from collections import defaultdict
from typing import (
AbstractSet,
Expand All @@ -12,13 +11,11 @@
Set,
Tuple,
Union,
cast,
)

from toposort import CircularDependencyError, toposort

import dagster._check as check
from dagster.config import Shape
from dagster.core.definitions.asset_layer import AssetLayer
from dagster.core.definitions.config import ConfigMapping
from dagster.core.definitions.dependency import (
Expand All @@ -33,15 +30,13 @@
from dagster.core.definitions.job_definition import JobDefinition
from dagster.core.definitions.output import OutputDefinition
from dagster.core.definitions.partition import PartitionedConfig, PartitionsDefinition
from dagster.core.definitions.partition_key_range import PartitionKeyRange
from dagster.core.definitions.resource_definition import ResourceDefinition
from dagster.core.errors import DagsterInvalidDefinitionError
from dagster.core.execution.with_resources import with_resources
from dagster.core.selector.subset_selector import AssetSelectionData
from dagster.utils import merge_dicts
from dagster.utils.backcompat import experimental

from .asset_partitions import get_upstream_partitions_for_partition_range
from .assets import AssetsDefinition
from .source_asset import SourceAsset

Expand All @@ -56,6 +51,7 @@ def build_assets_job(
config: Optional[Union[ConfigMapping, Dict[str, Any], PartitionedConfig]] = None,
tags: Optional[Dict[str, Any]] = None,
executor_def: Optional[ExecutorDefinition] = None,
partitions_def: Optional[PartitionsDefinition] = None,
_asset_selection_data: Optional[AssetSelectionData] = None,
) -> JobDefinition:
"""Builds a job that materializes the given assets.
Expand Down Expand Up @@ -99,17 +95,19 @@ def asset2(asset1):
)
check.opt_str_param(description, "description")
check.opt_inst_param(_asset_selection_data, "_asset_selection_data", AssetSelectionData)

# figure out what partitions (if any) exist for this job
partitions_def = partitions_def or build_job_partitions_from_assets(assets)

resource_defs = check.opt_mapping_param(resource_defs, "resource_defs")
resource_defs = merge_dicts({"io_manager": default_job_io_manager}, resource_defs)

assets = with_resources(assets, resource_defs)
source_assets = with_resources(source_assets, resource_defs)

source_assets_by_key = build_source_assets_by_key(source_assets)

partitioned_config = build_job_partitions_from_assets(assets, source_assets or [])

deps, assets_defs_by_node_handle = build_deps(assets, source_assets_by_key.keys())

# attempt to resolve cycles using multi-asset subsetting
if _has_cycles(deps):
assets = _attempt_resolve_cycles(assets)
Expand Down Expand Up @@ -141,18 +139,18 @@ def asset2(asset1):

return graph.to_job(
resource_defs=all_resource_defs,
config=config or partitioned_config,
config=config,
tags=tags,
executor_def=executor_def,
partitions_def=partitions_def,
asset_layer=asset_layer,
_asset_selection_data=_asset_selection_data,
)


def build_job_partitions_from_assets(
assets: Iterable[AssetsDefinition],
source_assets: Sequence[Union[SourceAsset, AssetsDefinition]],
) -> Optional[PartitionedConfig]:
) -> Optional[PartitionsDefinition]:
assets_with_partitions_defs = [assets_def for assets_def in assets if assets_def.partitions_def]

if len(assets_with_partitions_defs) == 0:
Expand All @@ -169,72 +167,7 @@ def build_job_partitions_from_assets(
f"'{second_asset_key}' have different partitions definitions. "
)

partitions_defs_by_asset_key: Dict[AssetKey, PartitionsDefinition] = {}
asset: Union[AssetsDefinition, SourceAsset]
for asset in itertools.chain.from_iterable([assets, source_assets]):
if isinstance(asset, AssetsDefinition) and asset.partitions_def is not None:
for asset_key in asset.asset_keys:
partitions_defs_by_asset_key[asset_key] = asset.partitions_def
elif isinstance(asset, SourceAsset) and asset.partitions_def is not None:
partitions_defs_by_asset_key[asset.key] = asset.partitions_def

def asset_partitions_for_job_partition(
job_partition_key: str,
) -> Mapping[AssetKey, PartitionKeyRange]:
return {
asset_key: PartitionKeyRange(job_partition_key, job_partition_key)
for assets_def in assets
for asset_key in assets_def.asset_keys
if assets_def.partitions_def
}

def run_config_for_partition_fn(partition_key: str) -> Dict[str, Any]:
ops_config: Dict[str, Any] = {}
asset_partitions_by_asset_key = asset_partitions_for_job_partition(partition_key)

for assets_def in assets:
outputs_dict: Dict[str, Dict[str, Any]] = {}
if assets_def.partitions_def is not None:
for output_name, asset_key in assets_def.asset_keys_by_output_name.items():
asset_partition_key_range = asset_partitions_by_asset_key[asset_key]
outputs_dict[output_name] = {
"start": asset_partition_key_range.start,
"end": asset_partition_key_range.end,
}

inputs_dict: Dict[str, Dict[str, Any]] = {}
for input_name, in_asset_key in assets_def.asset_keys_by_input_name.items():
upstream_partitions_def = partitions_defs_by_asset_key.get(in_asset_key)
if assets_def.partitions_def is not None and upstream_partitions_def is not None:
upstream_partition_key_range = get_upstream_partitions_for_partition_range(
assets_def, upstream_partitions_def, in_asset_key, asset_partition_key_range
)
inputs_dict[input_name] = {
"start": upstream_partition_key_range.start,
"end": upstream_partition_key_range.end,
}

config_schema = assets_def.node_def.config_schema
if (
config_schema
and isinstance(config_schema.config_type, Shape)
and "assets" in config_schema.config_type.fields
):
ops_config[assets_def.node_def.name] = {
"config": {
"assets": {
"input_partitions": inputs_dict,
"output_partitions": outputs_dict,
}
}
}

return {"ops": ops_config}

return PartitionedConfig(
partitions_def=cast(PartitionsDefinition, first_assets_with_partitions_def.partitions_def),
run_config_for_partition_fn=lambda p: run_config_for_partition_fn(p.name),
)
return first_assets_with_partitions_def.partitions_def


def build_source_assets_by_key(
Expand Down
31 changes: 30 additions & 1 deletion python_modules/dagster/dagster/core/definitions/asset_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dagster.utils.backcompat import ExperimentalWarning

from ..errors import DagsterInvalidSubsetError
from .config import ConfigMapping
from .dependency import NodeHandle, NodeInputHandle, NodeOutputHandle, SolidOutputHandle
from .executor_definition import ExecutorDefinition
from .graph_definition import GraphDefinition
Expand All @@ -37,7 +38,7 @@
from dagster.core.execution.context.output import OutputContext

from .job_definition import JobDefinition
from .partition import PartitionsDefinition
from .partition import PartitionedConfig, PartitionsDefinition


class AssetOutputInfo(
Expand Down Expand Up @@ -575,6 +576,9 @@ def source_assets_by_key(self) -> Mapping[AssetKey, "SourceAsset"]:
def assets_defs_by_key(self) -> Mapping[AssetKey, "AssetsDefinition"]:
return self._assets_defs_by_key

def assets_def_for_asset(self, asset_key: "AssetKey") -> "AssetsDefinition":
    """Return the AssetsDefinition registered for *asset_key*.

    Unlike ``partitions_def_for_asset``, this does not fall back to source
    assets; a KeyError propagates if the key is unknown.
    """
    assets_def = self._assets_defs_by_key[asset_key]
    return assets_def

def asset_keys_for_node(self, node_handle: NodeHandle) -> AbstractSet[AssetKey]:
return self._asset_keys_by_node_handle[node_handle]

Expand Down Expand Up @@ -616,12 +620,26 @@ def group_names_by_assets(self) -> Mapping[AssetKey, str]:

return group_names

def partitions_def_for_asset(self, asset_key: "AssetKey") -> Optional["PartitionsDefinition"]:
    """Return the PartitionsDefinition for *asset_key*, or None if unknown.

    Regular assets definitions are consulted before source assets; whichever
    registry holds the key supplies the (possibly None) partitions_def.
    """
    for registry in (self._assets_defs_by_key, self._source_assets_by_key):
        definition = registry.get(asset_key)
        if definition is not None:
            return definition.partitions_def
    return None


def build_asset_selection_job(
name: str,
assets: Iterable["AssetsDefinition"],
source_assets: Iterable["SourceAsset"],
executor_def: Optional[ExecutorDefinition] = None,
config: Optional[Union[ConfigMapping, Dict[str, Any], "PartitionedConfig"]] = None,
partitions_def: Optional["PartitionsDefinition"] = None,
resource_defs: Optional[Mapping[str, ResourceDefinition]] = None,
description: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
Expand All @@ -638,14 +656,25 @@ def build_asset_selection_job(
included_assets = cast(Iterable["AssetsDefinition"], assets)
excluded_assets = list(source_assets)

if partitions_def:
for asset in included_assets:
check.invariant(
asset.partitions_def == partitions_def,
f"Assets defined for node '{asset.node_def.name}' have a partitions_def of "
f"{asset.partitions_def}, but job '{name}' has non-matching partitions_def of "
f"{partitions_def}.",
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ExperimentalWarning)
asset_job = build_assets_job(
name=name,
assets=included_assets,
config=config,
source_assets=excluded_assets,
resource_defs=resource_defs,
executor_def=executor_def,
partitions_def=partitions_def,
description=description,
tags=tags,
_asset_selection_data=asset_selection_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
ScheduleEvaluationContext,
)
from .time_window_partitions import TimeWindow, TimeWindowPartitionsDefinition
from .unresolved_asset_job_definition import UnresolvedAssetJobDefinition


def build_schedule_from_partitioned_job(
job: JobDefinition,
job: Union[JobDefinition, UnresolvedAssetJobDefinition],
description: Optional[str] = None,
name: Optional[str] = None,
minute_of_hour: Optional[int] = None,
Expand All @@ -35,20 +36,24 @@ def build_schedule_from_partitioned_job(
The schedule executes at the cadence specified by the partitioning of the given job.
"""
check.invariant(len(job.mode_definitions) == 1, "job must only have one mode")
check.invariant(
job.mode_definitions[0].partitioned_config is not None, "job must be a partitioned job"
)
check.invariant(
not (day_of_week and day_of_month),
"Cannot provide both day_of_month and day_of_week parameter to build_schedule_from_partitioned_job.",
)
if isinstance(job, JobDefinition):
check.invariant(len(job.mode_definitions) == 1, "job must only have one mode")
check.invariant(
job.mode_definitions[0].partitioned_config is not None, "job must be a partitioned job"
)

partitioned_config = cast(PartitionedConfig, job.mode_definitions[0].partitioned_config)
partition_set = cast(PartitionSetDefinition, job.get_partition_set_def())
partitioned_config = cast(PartitionedConfig, job.mode_definitions[0].partitioned_config)
partition_set = cast(PartitionSetDefinition, job.get_partition_set_def())
partitions_def = cast(TimeWindowPartitionsDefinition, partitioned_config.partitions_def)
else:
partition_set = cast(PartitionSetDefinition, job.get_partition_set_def())
partitions_def = cast(TimeWindowPartitionsDefinition, job.partitions_def)

check.inst(partitioned_config.partitions_def, TimeWindowPartitionsDefinition)
partitions_def = cast(TimeWindowPartitionsDefinition, partitioned_config.partitions_def)
check.inst(partitions_def, TimeWindowPartitionsDefinition)

minute_of_hour = cast(
int,
Expand Down

0 comments on commit d684848

Please sign in to comment.