From fa64bb7fbce4736506de1dd0822c7da56de6bf44 Mon Sep 17 00:00:00 2001 From: prha Date: Sat, 11 May 2024 14:48:52 -0700 Subject: [PATCH] add batch fetching of data version records --- .../execution/context/data_version_cache.py | 107 +++++++++++------- 1 file changed, 68 insertions(+), 39 deletions(-) diff --git a/python_modules/dagster/dagster/_core/execution/context/data_version_cache.py b/python_modules/dagster/dagster/_core/execution/context/data_version_cache.py index c404ebe7c092a..80fbfc6e9c1bf 100644 --- a/python_modules/dagster/dagster/_core/execution/context/data_version_cache.py +++ b/python_modules/dagster/dagster/_core/execution/context/data_version_cache.py @@ -76,7 +76,7 @@ def maybe_fetch_and_get_input_asset_version_info( self, key: AssetKey ) -> Optional["InputAssetVersionInfo"]: if key not in self.input_asset_version_info: - self._fetch_input_asset_version_info(key) + self._fetch_input_asset_version_info([key]) return self.input_asset_version_info[key] # "external" refers to records for inputs generated outside of this step @@ -93,54 +93,83 @@ def fetch_external_input_asset_version_info(self) -> None: all_dep_keys.append(key) self.input_asset_version_info = {} - for key in all_dep_keys: - self._fetch_input_asset_version_info(key) + self._fetch_input_asset_version_info(all_dep_keys) self.is_external_input_asset_version_info_loaded = True - def _fetch_input_asset_version_info(self, key: AssetKey) -> None: + def _fetch_input_asset_version_info(self, asset_keys: Sequence[AssetKey]) -> None: from dagster._core.definitions.data_version import ( extract_data_version_from_entry, ) - event = self._get_input_asset_event(key) - if event is None: - self.input_asset_version_info[key] = None + asset_records = self._context.instance.get_asset_records(asset_keys) + materialization_records_by_key: Dict[AssetKey, Optional[EventLogRecord]] = { + record.asset_entry.asset_key: record.asset_entry.last_materialization_record + for record in asset_records + } + + if self._context.instance.event_log_storage.asset_records_have_last_observation: + observations_records_by_key: Dict[AssetKey, Optional[EventLogRecord]] = { + record.asset_entry.asset_key: record.asset_entry.last_observation_record + for record in asset_records + } else: - storage_id = event.storage_id - # Input name will be none if this is an internal dep - input_name = self._context.job_def.asset_layer.input_for_asset_key( - self._context.node_handle, key - ) - # Exclude AllPartitionMapping for now to avoid huge queries - if input_name and self._context.has_asset_partitions_for_input(input_name): - subset = self._context.asset_partitions_subset_for_input( - input_name, require_valid_partitions=False + observations_records_by_key: Dict[AssetKey, Optional[EventLogRecord]] = {} + for key in asset_keys: + if key in materialization_records_by_key: + # we only need to fetch the last observation record if we did not have a materialization record + continue + + last_observation_record = next( + iter(self._context.instance.fetch_observations(key, limit=1).records), None + ) + if last_observation_record: + observations_records_by_key[key] = last_observation_record + + records_by_key: Dict[AssetKey, Optional[EventLogRecord]] = { + **observations_records_by_key, + **materialization_records_by_key, + } + + for key in asset_keys: + event = records_by_key.get(key) + if event is None: + self.input_asset_version_info[key] = None + else: + storage_id = event.storage_id + # Input name will be none if this is an internal dep + input_name = self._context.job_def.asset_layer.input_for_asset_key( + self._context.node_handle, key ) - input_keys = list(subset.get_partition_keys()) - - # This check represents a temporary constraint that prevents huge query results for upstream - # partition data versions from timing out runs. If a partitioned dependency (a) uses an - # AllPartitionMapping; and (b) has greater than or equal to - # SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD dependency partitions, then we - # process it as a non-partitioned dependency (note that this was the behavior for - # all partition dependencies prior to 2023-08). This means that stale status - # results cannot be accurately computed for the dependency, and there is thus - # corresponding logic in the CachingStaleStatusResolver to account for this. This - # constraint should be removed when we have thoroughly examined the performance of - # the data version retrieval query and can guarantee decent performance. - if len(input_keys) < SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD: - data_version = self._get_partitions_data_version_from_keys(key, input_keys) + # Exclude AllPartitionMapping for now to avoid huge queries + if input_name and self._context.has_asset_partitions_for_input(input_name): + subset = self._context.asset_partitions_subset_for_input( + input_name, require_valid_partitions=False + ) + input_keys = list(subset.get_partition_keys()) + + # This check represents a temporary constraint that prevents huge query results for upstream + # partition data versions from timing out runs. If a partitioned dependency (a) uses an + # AllPartitionMapping; and (b) has greater than or equal to + # SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD dependency partitions, then we + # process it as a non-partitioned dependency (note that this was the behavior for + # all partition dependencies prior to 2023-08). This means that stale status + # results cannot be accurately computed for the dependency, and there is thus + # corresponding logic in the CachingStaleStatusResolver to account for this. This + # constraint should be removed when we have thoroughly examined the performance of + # the data version retrieval query and can guarantee decent performance. + if len(input_keys) < SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD: + data_version = self._get_partitions_data_version_from_keys(key, input_keys) + else: + data_version = extract_data_version_from_entry(event.event_log_entry) else: data_version = extract_data_version_from_entry(event.event_log_entry) - else: - data_version = extract_data_version_from_entry(event.event_log_entry) - self.input_asset_version_info[key] = InputAssetVersionInfo( - storage_id, - check.not_none(event.event_log_entry.dagster_event).event_type, - data_version, - event.run_id, - event.timestamp, - ) + self.input_asset_version_info[key] = InputAssetVersionInfo( + storage_id, + check.not_none(event.event_log_entry.dagster_event).event_type, + data_version, + event.run_id, + event.timestamp, + ) def _get_input_asset_event(self, key: AssetKey) -> Optional["EventLogRecord"]: event = self._context.instance.get_latest_data_version_record(key)