allow AssetGroups to have assets with different partitions defs (#7388)
sryza committed Apr 13, 2022
1 parent b09c8e1 commit 3b55c4e
Showing 7 changed files with 141 additions and 88 deletions.
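
In short, an AssetGroup may now mix assets whose partitions definitions differ: instead of coercing everything into a single "__ASSET_GROUP" job, the group builds one base job per partitions definition, and unpartitioned assets are folded into each of them. A minimal sketch of the new behavior, assembled from the docs snippet and tests changed in this commit (asset names are illustrative):

from dagster import AssetGroup, DailyPartitionsDefinition, HourlyPartitionsDefinition, asset


@asset(partitions_def=DailyPartitionsDefinition(start_date="2021-05-05"))
def daily_asset():
    ...


@asset(partitions_def=HourlyPartitionsDefinition(start_date="2021-05-05-00:00"))
def hourly_asset():
    ...


@asset
def unpartitioned_asset():
    ...


group = AssetGroup([daily_asset, hourly_asset, unpartitioned_asset])

# One base job is built per partitions definition; the unpartitioned asset
# appears in both of them.
for job_def in group.get_base_jobs():
    print(job_def.name)  # __ASSET_GROUP_0, __ASSET_GROUP_1
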
@@ -1,5 +1,7 @@
# pylint: disable=redefined-outer-name
from dagster import AssetGroup, DailyPartitionsDefinition, asset
from datetime import datetime

from dagster import AssetGroup, DailyPartitionsDefinition, HourlyPartitionsDefinition, asset

daily_partitions_def = DailyPartitionsDefinition(start_date="2020-01-01")

@@ -14,4 +16,14 @@ def downstream_daily_partitioned_asset(upstream_daily_partitioned_asset):
    assert upstream_daily_partitioned_asset is None


@asset(partitions_def=HourlyPartitionsDefinition(start_date=datetime(2022, 3, 12, 0, 0)))
def hourly_partitioned_asset():
    pass


@asset
def unpartitioned_asset():
    pass


partitioned_asset_group = AssetGroup.from_current_module()
@@ -303,8 +303,8 @@ def test_software_defined_assets_job():


def test_partitioned_assets():
assert (
partitioned_asset_group.build_job("all_assets")
.execute_in_process(partition_key="2020-02-01")
.success
)
    for job_def in partitioned_asset_group.get_base_jobs():
        partition_key = job_def.mode_definitions[
            0
        ].partitioned_config.partitions_def.get_partition_keys()[0]
        assert job_def.execute_in_process(partition_key=partition_key).success
55 changes: 50 additions & 5 deletions python_modules/dagster/dagster/core/asset_defs/asset_group.py
@@ -3,6 +3,7 @@
import pkgutil
import re
import warnings
from collections import defaultdict
from importlib import import_module
from types import ModuleType
from typing import (
@@ -22,6 +23,8 @@

from dagster import check
from dagster.core.definitions.events import AssetKey
from dagster.core.definitions.executor_definition import in_process_executor
from dagster.core.errors import DagsterUnmetExecutorRequirementsError
from dagster.core.execution.execute_in_process_result import ExecuteInProcessResult
from dagster.core.storage.fs_asset_io_manager import fs_asset_io_manager
from dagster.utils import merge_dicts
@@ -30,12 +33,15 @@
from ..definitions.executor_definition import ExecutorDefinition
from ..definitions.job_definition import JobDefinition
from ..definitions.op_definition import OpDefinition
from ..definitions.partition import PartitionsDefinition
from ..definitions.resource_definition import ResourceDefinition
from ..errors import DagsterInvalidDefinitionError
from .assets import AssetsDefinition
from .assets_job import build_assets_job, build_root_manager, build_source_assets_by_key
from .source_asset import SourceAsset

ASSET_GROUP_BASE_JOB_PREFIX = "__ASSET_GROUP"


class AssetGroup(
NamedTuple(
@@ -149,9 +155,8 @@ def __new__(
)

@staticmethod
def all_assets_job_name() -> str:
"""The name of the mega-job that the provided list of assets is coerced into."""
return "__ASSET_GROUP"
def is_base_job_name(name) -> bool:
return name.startswith(ASSET_GROUP_BASE_JOB_PREFIX)

def build_job(
self,
@@ -508,9 +513,49 @@ def materialize(
name="in_process_materialization_job", selection=selection
).execute_in_process()

def get_base_jobs(self) -> Sequence[JobDefinition]:
"""For internal use only."""
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ExperimentalWarning)

assets_by_partitions_def: Dict[
Optional[PartitionsDefinition], List[AssetsDefinition]
] = defaultdict(list)
for assets_def in self.assets:
assets_by_partitions_def[assets_def.partitions_def].append(assets_def)

if len(assets_by_partitions_def.keys()) == 0 or assets_by_partitions_def.keys() == {
None
}:
return [
build_assets_job(
ASSET_GROUP_BASE_JOB_PREFIX,
assets=self.assets,
source_assets=self.source_assets,
resource_defs=self.resource_defs,
executor_def=self.executor_def,
)
]
else:
unpartitioned_assets = assets_by_partitions_def.get(None, [])
jobs = []

# sort to ensure some stability in the ordering
for i, (partitions_def, assets_with_partitions) in enumerate(
sorted(assets_by_partitions_def.items(), key=lambda item: repr(item[0]))
):
if partitions_def is not None:
jobs.append(
build_assets_job(
f"{ASSET_GROUP_BASE_JOB_PREFIX}_{i}",
assets=assets_with_partitions + unpartitioned_assets,
source_assets=[*self.source_assets, *self.assets],
resource_defs=self.resource_defs,
executor_def=self.executor_def,
)
)

from dagster.core.definitions.executor_definition import in_process_executor
from dagster.core.errors import DagsterUnmetExecutorRequirementsError
                return jobs


def _find_assets_in_module(
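
Because a group can now produce more than one base job, the exact reserved name returned by all_assets_job_name() gives way to a prefix check. A small sketch of the new helper (job names follow the pattern built by get_base_jobs above):

from dagster import AssetGroup

# Base job names share the __ASSET_GROUP prefix; groups that contain
# partitioned assets get an index suffix per partitions definition.
assert AssetGroup.is_base_job_name("__ASSET_GROUP")
assert AssetGroup.is_base_job_name("__ASSET_GROUP_0")
assert not AssetGroup.is_base_job_name("my_assets_job")
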
3 changes: 3 additions & 0 deletions python_modules/dagster/dagster/core/definitions/partition.py
@@ -202,6 +202,9 @@ def get_partitions(
    ) -> List[Partition[str]]:
        return self._partitions

    def __hash__(self):
        return hash(self.__repr__())

    def __eq__(self, other) -> bool:
        return (
            isinstance(other, StaticPartitionsDefinition)
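
StaticPartitionsDefinition gains __hash__ because get_base_jobs groups assets in a dict keyed by their PartitionsDefinition, and a Python class that defines __eq__ without __hash__ is unhashable. A small sketch of the requirement (the asset name is illustrative):

from dagster import StaticPartitionsDefinition

colors = StaticPartitionsDefinition(["red", "green", "blue"])

# Using a partitions definition as a dict key, as get_base_jobs now does,
# would raise "TypeError: unhashable type" without the added __hash__.
assets_by_partitions_def = {colors: ["colored_asset"]}
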
@@ -1,4 +1,3 @@
import warnings
from abc import ABC, abstractmethod
from inspect import isfunction
from types import FunctionType
@@ -21,7 +20,6 @@
from dagster.core.asset_defs.source_asset import SourceAsset
from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster.utils import merge_dicts
from dagster.utils.backcompat import ExperimentalWarning

from .events import AssetKey
from .graph_definition import GraphDefinition, SubselectedGraphDefinition
@@ -623,13 +621,14 @@ def from_list(
Use this constructor when you have no need to lazy load pipelines/jobs or other
definitions.
"""
from dagster.core.asset_defs import AssetGroup, build_assets_job
from dagster.core.asset_defs import AssetGroup

pipelines_or_jobs: Dict[str, Union[PipelineDefinition, JobDefinition]] = {}
partition_sets: Dict[str, PartitionSetDefinition] = {}
schedules: Dict[str, ScheduleDefinition] = {}
sensors: Dict[str, SensorDefinition] = {}
source_assets: Dict[AssetKey, SourceAsset] = {}
encountered_asset_group = False
for definition in repository_definitions:
if isinstance(definition, PipelineDefinition):
if (
@@ -641,9 +640,10 @@
target_type=definition.target_type, target=definition.describe_target()
)
)
if definition.name == AssetGroup.all_assets_job_name():
if AssetGroup.is_base_job_name(definition.name):
raise DagsterInvalidDefinitionError(
f"Attempted to provide job called {AssetGroup.all_assets_job_name()} to repository, which is a reserved name. Please rename the job."
f"Attempted to provide job called {definition.name} to repository, which "
"is a reserved name. Please rename the job."
)
pipelines_or_jobs[definition.name] = definition
elif isinstance(definition, PartitionSetDefinition):
@@ -694,21 +694,18 @@
pipelines_or_jobs[coerced.name] = coerced

elif isinstance(definition, AssetGroup):
asset_group = definition

if asset_group.all_assets_job_name() in pipelines_or_jobs:
if encountered_asset_group:
raise DagsterInvalidDefinitionError(
"When constructing repository, attempted to pass multiple AssetGroups. There can only be one AssetGroup per repository."
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ExperimentalWarning)
pipelines_or_jobs[asset_group.all_assets_job_name()] = build_assets_job(
asset_group.all_assets_job_name(),
assets=asset_group.assets,
source_assets=asset_group.source_assets,
resource_defs=asset_group.resource_defs,
executor_def=asset_group.executor_def,
"When constructing repository, attempted to pass multiple AssetGroups. "
"There can only be one AssetGroup per repository."
)

encountered_asset_group = True
asset_group = definition

for job_def in asset_group.get_base_jobs():
pipelines_or_jobs[job_def.name] = job_def

source_assets = {
source_asset.key: source_asset for source_asset in asset_group.source_assets
}
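
Correspondingly, building a repository from a list no longer constructs a single __ASSET_GROUP job itself; it registers every base job the AssetGroup produces. A self-contained sketch of the resulting repository contents, assuming two partitions definitions (repository and asset names are illustrative):

from dagster import AssetGroup, DailyPartitionsDefinition, HourlyPartitionsDefinition, asset, repository


@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def daily_asset():
    ...


@asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00"))
def hourly_asset():
    ...


@repository
def repo_with_mixed_partitions():
    return [AssetGroup([daily_asset, hourly_asset])]


# The repository now exposes one base job per partitions definition.
assert len(repo_with_mixed_partitions.get_all_jobs()) == 2
assert all(
    AssetGroup.is_base_job_name(job_def.name)
    for job_def in repo_with_mixed_partitions.get_all_jobs()
)
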
@@ -6,6 +6,8 @@
from dagster import (
AssetKey,
DagsterInvalidDefinitionError,
DailyPartitionsDefinition,
HourlyPartitionsDefinition,
IOManager,
Out,
fs_asset_io_manager,
@@ -56,7 +58,7 @@ def the_repo():

assert len(the_repo.get_all_jobs()) == 1
asset_group_underlying_job = the_repo.get_all_jobs()[0]
assert asset_group_underlying_job.name == group.all_assets_job_name()
assert AssetGroup.is_base_job_name(asset_group_underlying_job.name)

result = asset_group_underlying_job.execute_in_process()
assert result.success
@@ -91,7 +93,7 @@ def the_repo():
return [group]

asset_group_underlying_job = the_repo.get_all_jobs()[0]
assert asset_group_underlying_job.name == group.all_assets_job_name()
assert AssetGroup.is_base_job_name(asset_group_underlying_job.name)

result = asset_group_underlying_job.execute_in_process()
assert result.success
@@ -113,7 +115,7 @@ def the_repo():
return [group]

asset_group_underlying_job = the_repo.get_all_jobs()[0]
assert asset_group_underlying_job.name == group.all_assets_job_name()
assert AssetGroup.is_base_job_name(asset_group_underlying_job.name)

result = asset_group_underlying_job.execute_in_process()
assert result.success
@@ -456,10 +458,10 @@ def test_job_with_reserved_name():
def the_graph():
pass

the_job = the_graph.to_job(name=AssetGroup.all_assets_job_name())
the_job = the_graph.to_job(name="__ASSET_GROUP")
with pytest.raises(
DagsterInvalidDefinitionError,
match=f"Attempted to provide job called {AssetGroup.all_assets_job_name()} to repository, which is a reserved name.",
match="Attempted to provide job called __ASSET_GROUP to repository, which is a reserved name.",
):

@repository
@@ -521,3 +523,50 @@ def follows_o2(o2):
assert result.output_for_node("middle_asset", "o1") == "foo"
assert result.output_for_node("follows_o2") == "foo"
assert result.output_for_node("start_asset") == "foo"


def test_multiple_partitions_defs():
    @asset(partitions_def=DailyPartitionsDefinition(start_date="2021-05-05"))
    def daily_asset():
        ...

    @asset(partitions_def=DailyPartitionsDefinition(start_date="2021-05-05"))
    def daily_asset2():
        ...

    @asset(partitions_def=DailyPartitionsDefinition(start_date="2020-05-05"))
    def daily_asset_different_start_date():
        ...

    @asset(partitions_def=HourlyPartitionsDefinition(start_date="2021-05-05-00:00"))
    def hourly_asset():
        ...

    @asset
    def unpartitioned_asset():
        ...

    group = AssetGroup(
        [
            daily_asset,
            daily_asset2,
            daily_asset_different_start_date,
            hourly_asset,
            unpartitioned_asset,
        ]
    )

    jobs = group.get_base_jobs()
    assert len(jobs) == 3
    assert {job_def.name for job_def in jobs} == {
        "__ASSET_GROUP_0",
        "__ASSET_GROUP_1",
        "__ASSET_GROUP_2",
    }
    assert {
        frozenset([node_def.name for node_def in job_def.all_node_defs]) for job_def in jobs
    } == {
        frozenset(["daily_asset", "daily_asset2", "unpartitioned_asset"]),
        frozenset(["hourly_asset", "unpartitioned_asset"]),
        frozenset(["daily_asset_different_start_date", "unpartitioned_asset"]),
    }
@@ -1,6 +1,6 @@
import pytest

from dagster import AssetGroup, AssetKey, DagsterInvariantViolationError, Out
from dagster import AssetKey, DagsterInvariantViolationError, Out
from dagster.check import CheckError
from dagster.core.asset_defs import AssetIn, SourceAsset, asset, build_assets_job, multi_asset
from dagster.core.definitions.metadata import MetadataEntry, MetadataValue
@@ -79,59 +79,6 @@ def asset2(asset1):
]


def test_two_asset_job_with_group():
@asset
def asset1():
return 1

@asset
def asset2(asset1):
assert asset1 == 1

asset_group = AssetGroup(assets=[asset1, asset2])
asset_group_job = build_assets_job(
asset_group.all_assets_job_name(),
assets=asset_group.assets,
source_assets=asset_group.source_assets,
resource_defs=asset_group.resource_defs,
executor_def=asset_group.executor_def,
)
assets_job = asset_group.build_job(name="assets_job")

external_asset_nodes = external_asset_graph_from_defs(
[asset_group_job, assets_job], source_assets_by_key={}
)

assert external_asset_nodes == [
ExternalAssetNode(
asset_key=AssetKey("asset1"),
dependencies=[],
depended_by=[
ExternalAssetDependedBy(
downstream_asset_key=AssetKey("asset2"), input_name="asset1"
)
],
op_name="asset1",
op_description=None,
job_names=[AssetGroup.all_assets_job_name(), "assets_job"],
output_name="result",
output_description=None,
),
ExternalAssetNode(
asset_key=AssetKey("asset2"),
dependencies=[
ExternalAssetDependency(upstream_asset_key=AssetKey("asset1"), input_name="asset1")
],
depended_by=[],
op_name="asset2",
op_description=None,
job_names=[AssetGroup.all_assets_job_name(), "assets_job"],
output_name="result",
output_description=None,
),
]


def test_input_name_matches_output_name():
not_result = SourceAsset(key=AssetKey("not_result"), description=None)
