Skip to content

Commit

Permalink
Definitions.merge
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed May 22, 2024
1 parent e0846dc commit 1681396
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 13 deletions.
150 changes: 137 additions & 13 deletions python_modules/dagster/dagster/_core/definitions/definitions_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,17 +252,7 @@ def _create_repository_using_definitions_args(
executor: Optional[Union[ExecutorDefinition, Executor]] = None,
loggers: Optional[Mapping[str, LoggerDefinition]] = None,
asset_checks: Optional[Iterable[AssetChecksDefinition]] = None,
):
check.opt_iterable_param(
assets, "assets", (AssetsDefinition, SourceAsset, CacheableAssetsDefinition)
)
check.opt_iterable_param(
schedules, "schedules", (ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition)
)
check.opt_iterable_param(sensors, "sensors", SensorDefinition)
check.opt_iterable_param(jobs, "jobs", (JobDefinition, UnresolvedAssetJobDefinition))

check.opt_inst_param(executor, "executor", (ExecutorDefinition, Executor))
) -> RepositoryDefinition:
executor_def = (
executor
if isinstance(executor, ExecutorDefinition) or executor is None
Expand All @@ -285,8 +275,6 @@ def _create_repository_using_definitions_args(

resource_defs = wrap_resources_for_execution(resources_with_key_mapping)

check.opt_mapping_param(loggers, "loggers", key_type=str, value_type=LoggerDefinition)

# Binds top-level resources to jobs and any jobs attached to schedules or sensors
(
jobs_with_resources,
Expand Down Expand Up @@ -415,6 +403,15 @@ class Definitions:
Any other object is coerced to a :py:class:`ResourceDefinition`.
"""

# Validated copies of the constructor arguments, stored by __init__ after
# the check.* coercions and exposed read-only through the matching
# properties. Names are underscore-prefixed to mark them non-public.
_assets: Iterable[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]
_schedules: Iterable[Union[ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition]]
_sensors: Iterable[SensorDefinition]
_jobs: Iterable[Union[JobDefinition, UnresolvedAssetJobDefinition]]
# Resource values may be arbitrary objects; they are coerced to
# ResourceDefinitions elsewhere, so the mapping value type is Any.
_resources: Mapping[str, Any]
_executor: Optional[Union[ExecutorDefinition, Executor]]
_loggers: Mapping[str, LoggerDefinition]
_asset_checks: Iterable[AssetChecksDefinition]

def __init__(
self,
assets: Optional[
Expand All @@ -430,6 +427,29 @@ def __init__(
loggers: Optional[Mapping[str, LoggerDefinition]] = None,
asset_checks: Optional[Iterable[AssetChecksDefinition]] = None,
):
self._assets = check.opt_iterable_param(
assets,
"assets",
(AssetsDefinition, SourceAsset, CacheableAssetsDefinition),
)
self._schedules = check.opt_iterable_param(
schedules,
"schedules",
(ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition),
)
self._sensors = check.opt_iterable_param(sensors, "sensors", SensorDefinition)
self._jobs = check.opt_iterable_param(
jobs, "jobs", (JobDefinition, UnresolvedAssetJobDefinition)
)
self._asset_checks = check.opt_iterable_param(
asset_checks, "asset_checks", AssetChecksDefinition
)
self._resources = check.opt_mapping_param(resources, "resources", key_type=str)
self._executor = check.opt_inst_param(executor, "executor", (ExecutorDefinition, Executor))
self._loggers = check.opt_mapping_param(
loggers, "loggers", key_type=str, value_type=LoggerDefinition
)

self._created_pending_or_normal_repo = _create_repository_using_definitions_args(
name=SINGLETON_REPOSITORY_NAME,
assets=assets,
Expand All @@ -442,6 +462,40 @@ def __init__(
asset_checks=asset_checks,
)

@property
def assets(self) -> Iterable[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
    """The assets that were passed to this Definitions object at construction time."""
    return self._assets

@property
def schedules(
    self,
) -> Iterable[Union[ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition]]:
    """The schedules that were passed to this Definitions object at construction time."""
    return self._schedules

@property
def sensors(self) -> Iterable[SensorDefinition]:
    """The sensors that were passed to this Definitions object at construction time."""
    return self._sensors

@property
def jobs(self) -> Iterable[Union[JobDefinition, UnresolvedAssetJobDefinition]]:
    """The jobs that were passed to this Definitions object at construction time."""
    return self._jobs

@property
def resources(self) -> Mapping[str, Any]:
    """The resources that were passed to this Definitions object at construction time.

    Values are the raw objects as supplied; coercion to ResourceDefinition
    happens separately when the repository is built.
    """
    return self._resources

@property
def executor(self) -> Optional[Union[ExecutorDefinition, Executor]]:
    """The executor that was passed to this Definitions object at construction time, if any."""
    return self._executor

@property
def loggers(self) -> Mapping[str, LoggerDefinition]:
    """The loggers that were passed to this Definitions object at construction time."""
    return self._loggers

@property
def asset_checks(self) -> Iterable[AssetChecksDefinition]:
    """The asset checks that were passed to this Definitions object at construction time."""
    return self._asset_checks

@public
def get_job_def(self, name: str) -> JobDefinition:
"""Get a job definition by name. If you passed in a an :py:class:`UnresolvedAssetJobDefinition`
Expand Down Expand Up @@ -573,3 +627,73 @@ def get_inner_repository_for_loading_process(
def get_asset_graph(self) -> AssetGraph:
    """Get the AssetGraph for this set of definitions."""
    repository = self.get_repository_def()
    return repository.asset_graph

@staticmethod
def merge(*def_sets: "Definitions") -> "Definitions":
    """Merges multiple Definitions objects into a single Definitions object.

    The returned Definitions object has the union of all the definitions in
    the input Definitions objects.

    Raises:
        DagsterInvariantViolationError: If two of the input objects declare a
            resource with the same key, a logger with the same key, or
            different executors.

    Returns:
        Definitions: The merged definitions.
    """
    check.sequence_param(def_sets, "def_sets", of_type=Definitions)

    merged_assets = []
    merged_schedules = []
    merged_sensors = []
    merged_jobs = []
    merged_asset_checks = []

    merged_resources = {}
    # Track which input object each resource/logger key came from so that
    # conflict errors can name both positions.
    resource_origins: Dict[str, int] = {}
    merged_loggers = {}
    logger_origins: Dict[str, int] = {}
    merged_executor = None
    executor_origin: Optional[int] = None

    for index, defs in enumerate(def_sets):
        merged_assets.extend(defs.assets or [])
        merged_asset_checks.extend(defs.asset_checks or [])
        merged_schedules.extend(defs.schedules or [])
        merged_sensors.extend(defs.sensors or [])
        merged_jobs.extend(defs.jobs or [])

        for resource_key, resource_value in (defs.resources or {}).items():
            if resource_key in merged_resources:
                raise DagsterInvariantViolationError(
                    f"Definitions objects {resource_origins[resource_key]} and {index} both have a "
                    f"resource with key '{resource_key}'"
                )
            merged_resources[resource_key] = resource_value
            resource_origins[resource_key] = index

        for logger_key, logger_value in (defs.loggers or {}).items():
            if logger_key in merged_loggers:
                raise DagsterInvariantViolationError(
                    f"Definitions objects {logger_origins[logger_key]} and {index} both have a "
                    f"logger with key '{logger_key}'"
                )
            merged_loggers[logger_key] = logger_value
            logger_origins[logger_key] = index

        if defs.executor is not None:
            # Identical executors from multiple inputs are allowed; only a
            # genuine mismatch is an error.
            if merged_executor is not None and merged_executor != defs.executor:
                raise DagsterInvariantViolationError(
                    f"Definitions objects {executor_origin} and {index} both have an executor"
                )

            merged_executor = defs.executor
            executor_origin = index

    return Definitions(
        assets=merged_assets,
        schedules=merged_schedules,
        sensors=merged_sensors,
        jobs=merged_jobs,
        resources=merged_resources,
        executor=merged_executor,
        loggers=merged_loggers,
        asset_checks=merged_asset_checks,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
in_process_executor,
materialize,
mem_io_manager,
multiprocess_executor,
op,
repository,
sensor,
Expand Down Expand Up @@ -756,3 +757,100 @@ def c(b):
s = SourceAsset(key="s")
with pytest.raises(CircularDependencyError):
Definitions(assets=[a, b, c, s]).get_all_job_defs()


def test_merge():
    """Merging two disjoint Definitions objects unions every category."""

    @asset
    def asset1(): ...

    @asset
    def asset2(): ...

    @job
    def job1(): ...

    @job
    def job2(): ...

    @sensor(job=job1)
    def sensor1(): ...

    @sensor(job=job2)
    def sensor2(): ...

    schedule1 = ScheduleDefinition(name="schedule1", job=job1, cron_schedule="@daily")
    schedule2 = ScheduleDefinition(name="schedule2", job=job2, cron_schedule="@daily")

    @logger
    def logger1(_):
        raise Exception("not executed")

    @logger
    def logger2(_):
        raise Exception("not executed")

    resource1 = object()
    resource2 = object()

    # Only the first object carries an executor, so no executor conflict.
    defs1 = Definitions(
        assets=[asset1],
        jobs=[job1],
        schedules=[schedule1],
        sensors=[sensor1],
        resources={"resource1": resource1},
        loggers={"logger1": logger1},
        executor=in_process_executor,
    )
    defs2 = Definitions(
        assets=[asset2],
        jobs=[job2],
        schedules=[schedule2],
        sensors=[sensor2],
        resources={"resource2": resource2},
        loggers={"logger2": logger2},
    )

    merged = Definitions.merge(defs1, defs2)

    # List-valued categories concatenate in input order.
    assert merged.assets == [asset1, asset2]
    assert merged.jobs == [job1, job2]
    assert merged.schedules == [schedule1, schedule2]
    assert merged.sensors == [sensor1, sensor2]
    # Mapping-valued categories union; the lone executor carries through.
    assert merged.resources == {"resource1": resource1, "resource2": resource2}
    assert merged.loggers == {"logger1": logger1, "logger2": logger2}
    assert merged.executor == in_process_executor


def test_resource_conflict_on_merge():
    """A duplicate resource key across inputs raises, naming both positions."""
    first = Definitions(resources={"resource1": 4})
    second = Definitions(resources={"resource1": 4})

    expected = "Definitions objects 0 and 1 both have a resource with key 'resource1'"
    with pytest.raises(DagsterInvariantViolationError, match=expected):
        Definitions.merge(first, second)


def test_logger_conflict_on_merge():
    """A duplicate logger key across inputs raises, naming both positions."""

    @logger
    def logger1(_):
        raise Exception("not executed")

    first = Definitions(loggers={"logger1": logger1})
    second = Definitions(loggers={"logger1": logger1})

    expected = "Definitions objects 0 and 1 both have a logger with key 'logger1'"
    with pytest.raises(DagsterInvariantViolationError, match=expected):
        Definitions.merge(first, second)


def test_executor_conflict_on_merge():
    """Two different executors across inputs raise, naming both positions."""
    first = Definitions(executor=in_process_executor)
    second = Definitions(executor=multiprocess_executor)

    expected = "Definitions objects 0 and 1 both have an executor"
    with pytest.raises(DagsterInvariantViolationError, match=expected):
        Definitions.merge(first, second)

0 comments on commit 1681396

Please sign in to comment.