From 16813967cf2c78a06bbac8dd68a63ba253112d1c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 22 May 2024 11:35:04 -0700 Subject: [PATCH] Definitions.merge --- .../_core/definitions/definitions_class.py | 150 ++++++++++++++++-- .../test_definitions_class.py | 98 ++++++++++++ 2 files changed, 235 insertions(+), 13 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/definitions_class.py b/python_modules/dagster/dagster/_core/definitions/definitions_class.py index b3a2d8da856c..27231b65dfec 100644 --- a/python_modules/dagster/dagster/_core/definitions/definitions_class.py +++ b/python_modules/dagster/dagster/_core/definitions/definitions_class.py @@ -252,17 +252,7 @@ def _create_repository_using_definitions_args( executor: Optional[Union[ExecutorDefinition, Executor]] = None, loggers: Optional[Mapping[str, LoggerDefinition]] = None, asset_checks: Optional[Iterable[AssetChecksDefinition]] = None, -): - check.opt_iterable_param( - assets, "assets", (AssetsDefinition, SourceAsset, CacheableAssetsDefinition) - ) - check.opt_iterable_param( - schedules, "schedules", (ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition) - ) - check.opt_iterable_param(sensors, "sensors", SensorDefinition) - check.opt_iterable_param(jobs, "jobs", (JobDefinition, UnresolvedAssetJobDefinition)) - - check.opt_inst_param(executor, "executor", (ExecutorDefinition, Executor)) +) -> RepositoryDefinition: executor_def = ( executor if isinstance(executor, ExecutorDefinition) or executor is None @@ -285,8 +275,6 @@ def _create_repository_using_definitions_args( resource_defs = wrap_resources_for_execution(resources_with_key_mapping) - check.opt_mapping_param(loggers, "loggers", key_type=str, value_type=LoggerDefinition) - # Binds top-level resources to jobs and any jobs attached to schedules or sensors ( jobs_with_resources, @@ -415,6 +403,15 @@ class Definitions: Any other object is coerced to a :py:class:`ResourceDefinition`. """ + _assets: Iterable[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]] + _schedules: Iterable[Union[ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition]] + _sensors: Iterable[SensorDefinition] + _jobs: Iterable[Union[JobDefinition, UnresolvedAssetJobDefinition]] + _resources: Mapping[str, Any] + _executor: Optional[Union[ExecutorDefinition, Executor]] + _loggers: Mapping[str, LoggerDefinition] + _asset_checks: Iterable[AssetChecksDefinition] + def __init__( self, assets: Optional[ @@ -430,6 +427,29 @@ def __init__( loggers: Optional[Mapping[str, LoggerDefinition]] = None, asset_checks: Optional[Iterable[AssetChecksDefinition]] = None, ): + self._assets = check.opt_iterable_param( + assets, + "assets", + (AssetsDefinition, SourceAsset, CacheableAssetsDefinition), + ) + self._schedules = check.opt_iterable_param( + schedules, + "schedules", + (ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition), + ) + self._sensors = check.opt_iterable_param(sensors, "sensors", SensorDefinition) + self._jobs = check.opt_iterable_param( + jobs, "jobs", (JobDefinition, UnresolvedAssetJobDefinition) + ) + self._asset_checks = check.opt_iterable_param( + asset_checks, "asset_checks", AssetChecksDefinition + ) + self._resources = check.opt_mapping_param(resources, "resources", key_type=str) + self._executor = check.opt_inst_param(executor, "executor", (ExecutorDefinition, Executor)) + self._loggers = check.opt_mapping_param( + loggers, "loggers", key_type=str, value_type=LoggerDefinition + ) + self._created_pending_or_normal_repo = _create_repository_using_definitions_args( name=SINGLETON_REPOSITORY_NAME, assets=assets, @@ -442,6 +462,40 @@ def __init__( asset_checks=asset_checks, ) + @property + def assets(self) -> Iterable[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]: + return self._assets + + @property + def schedules( + self, + ) -> Iterable[Union[ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition]]: + return self._schedules + + @property + def sensors(self) -> Iterable[SensorDefinition]: + return self._sensors + + @property + def jobs(self) -> Iterable[Union[JobDefinition, UnresolvedAssetJobDefinition]]: + return self._jobs + + @property + def resources(self) -> Mapping[str, Any]: + return self._resources + + @property + def executor(self) -> Optional[Union[ExecutorDefinition, Executor]]: + return self._executor + + @property + def loggers(self) -> Mapping[str, LoggerDefinition]: + return self._loggers + + @property + def asset_checks(self) -> Iterable[AssetChecksDefinition]: + return self._asset_checks + @public def get_job_def(self, name: str) -> JobDefinition: """Get a job definition by name. If you passed in a an :py:class:`UnresolvedAssetJobDefinition` @@ -573,3 +627,73 @@ def get_inner_repository_for_loading_process( def get_asset_graph(self) -> AssetGraph: """Get the AssetGraph for this set of definitions.""" return self.get_repository_def().asset_graph + + @staticmethod + def merge(*def_sets: "Definitions") -> "Definitions": + """Merges multiple Definitions objects into a single Definitions object. + + The returned Definitions object has the union of all the definitions in the input + Definitions objects. + + Returns: + Definitions: The merged definitions. + """ + check.sequence_param(def_sets, "def_sets", of_type=Definitions) + + assets = [] + schedules = [] + sensors = [] + jobs = [] + asset_checks = [] + + resources = {} + resource_key_indexes: Dict[str, int] = {} + loggers = {} + logger_key_indexes: Dict[str, int] = {} + executor = None + executor_index: Optional[int] = None + + for i, def_set in enumerate(def_sets): + assets.extend(def_set.assets or []) + asset_checks.extend(def_set.asset_checks or []) + schedules.extend(def_set.schedules or []) + sensors.extend(def_set.sensors or []) + jobs.extend(def_set.jobs or []) + + for resource_key, resource_value in (def_set.resources or {}).items(): + if resource_key in resources: + raise DagsterInvariantViolationError( + f"Definitions objects {resource_key_indexes[resource_key]} and {i} both have a " + f"resource with key '{resource_key}'" + ) + resources[resource_key] = resource_value + resource_key_indexes[resource_key] = i + + for logger_key, logger_value in (def_set.loggers or {}).items(): + if logger_key in loggers: + raise DagsterInvariantViolationError( + f"Definitions objects {logger_key_indexes[logger_key]} and {i} both have a " + f"logger with key '{logger_key}'" + ) + loggers[logger_key] = logger_value + logger_key_indexes[logger_key] = i + + if def_set.executor is not None: + if executor is not None and executor != def_set.executor: + raise DagsterInvariantViolationError( + f"Definitions objects {executor_index} and {i} both have an executor" + ) + + executor = def_set.executor + executor_index = i + + return Definitions( + assets=assets, + schedules=schedules, + sensors=sensors, + jobs=jobs, + resources=resources, + executor=executor, + loggers=loggers, + asset_checks=asset_checks, + ) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_definitions_class.py b/python_modules/dagster/dagster_tests/definitions_tests/test_definitions_class.py index 4b3097ab5171..790cf5b3c729 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_definitions_class.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_definitions_class.py @@ -18,6 +18,7 @@ in_process_executor, materialize, mem_io_manager, + multiprocess_executor, op, repository, sensor, @@ -756,3 +757,100 @@ def c(b): s = SourceAsset(key="s") with pytest.raises(CircularDependencyError): Definitions(assets=[a, b, c, s]).get_all_job_defs() + + +def test_merge(): + @asset + def asset1(): ... + + @asset + def asset2(): ... + + @job + def job1(): ... + + @job + def job2(): ... + + schedule1 = ScheduleDefinition(name="schedule1", job=job1, cron_schedule="@daily") + schedule2 = ScheduleDefinition(name="schedule2", job=job2, cron_schedule="@daily") + + @sensor(job=job1) + def sensor1(): ... + + @sensor(job=job2) + def sensor2(): ... + + resource1 = object() + resource2 = object() + + @logger + def logger1(_): + raise Exception("not executed") + + @logger + def logger2(_): + raise Exception("not executed") + + defs1 = Definitions( + assets=[asset1], + jobs=[job1], + schedules=[schedule1], + sensors=[sensor1], + resources={"resource1": resource1}, + loggers={"logger1": logger1}, + executor=in_process_executor, + ) + defs2 = Definitions( + assets=[asset2], + jobs=[job2], + schedules=[schedule2], + sensors=[sensor2], + resources={"resource2": resource2}, + loggers={"logger2": logger2}, + ) + + merged = Definitions.merge(defs1, defs2) + assert merged.assets == [asset1, asset2] + assert merged.jobs == [job1, job2] + assert merged.schedules == [schedule1, schedule2] + assert merged.sensors == [sensor1, sensor2] + assert merged.resources == {"resource1": resource1, "resource2": resource2} + assert merged.loggers == {"logger1": logger1, "logger2": logger2} + assert merged.executor == in_process_executor + + +def test_resource_conflict_on_merge(): + defs1 = Definitions(resources={"resource1": 4}) + defs2 = Definitions(resources={"resource1": 4}) + + with pytest.raises( + DagsterInvariantViolationError, + match="Definitions objects 0 and 1 both have a resource with key 'resource1'", + ): + Definitions.merge(defs1, defs2) + + +def test_logger_conflict_on_merge(): + @logger + def logger1(_): + raise Exception("not executed") + + defs1 = Definitions(loggers={"logger1": logger1}) + defs2 = Definitions(loggers={"logger1": logger1}) + + with pytest.raises( + DagsterInvariantViolationError, + match="Definitions objects 0 and 1 both have a logger with key 'logger1'", + ): + Definitions.merge(defs1, defs2) + + +def test_executor_conflict_on_merge(): + defs1 = Definitions(executor=in_process_executor) + defs2 = Definitions(executor=multiprocess_executor) + + with pytest.raises( + DagsterInvariantViolationError, match="Definitions objects 0 and 1 both have an executor" + ): + Definitions.merge(defs1, defs2)