add batch fetching of data version records #21798

Merged · 6 commits · May 13, 2024
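Batches the retrieval of asset records used to resolve input data versions. `_fetch_input_asset_version_info` now accepts a sequence of asset keys and loads their `AssetRecord`s through `DagsterInstance.get_asset_records` in chunks whose size is read from `GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE` (default 100), instead of issuing one `get_latest_data_version_record` query per upstream key. A new fan-in test asserts that materializing an asset with 100 upstreams performs a single `get_asset_records` call.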
First changed file (internal execution context module):

@@ -1,3 +1,4 @@
+import os
 from typing import TYPE_CHECKING

 """This module contains the execution context objects that are internal to the system.

@@ -29,6 +30,7 @@
 )
 from dagster._core.event_api import EventLogRecord
 from dagster._core.events import DagsterEventType
+from dagster._core.storage.event_log.base import AssetRecord


 if TYPE_CHECKING:

@@ -76,7 +78,7 @@ def maybe_fetch_and_get_input_asset_version_info(
         self, key: AssetKey
     ) -> Optional["InputAssetVersionInfo"]:
         if key not in self.input_asset_version_info:
-            self._fetch_input_asset_version_info(key)
+            self._fetch_input_asset_version_info([key])
         return self.input_asset_version_info[key]

     # "external" refers to records for inputs generated outside of this step
@@ -93,57 +95,88 @@ def fetch_external_input_asset_version_info(self) -> None:
             all_dep_keys.append(key)

         self.input_asset_version_info = {}
-        for key in all_dep_keys:
-            self._fetch_input_asset_version_info(key)
+        self._fetch_input_asset_version_info(all_dep_keys)
         self.is_external_input_asset_version_info_loaded = True

-    def _fetch_input_asset_version_info(self, key: AssetKey) -> None:
+    def _fetch_input_asset_version_info(self, asset_keys: Sequence[AssetKey]) -> None:
         from dagster._core.definitions.data_version import (
             extract_data_version_from_entry,
         )

-        event = self._get_input_asset_event(key)
-        if event is None:
-            self.input_asset_version_info[key] = None
-        else:
-            storage_id = event.storage_id
-            # Input name will be none if this is an internal dep
-            input_name = self._context.job_def.asset_layer.input_for_asset_key(
-                self._context.node_handle, key
-            )
-            # Exclude AllPartitionMapping for now to avoid huge queries
-            if input_name and self._context.has_asset_partitions_for_input(input_name):
-                subset = self._context.asset_partitions_subset_for_input(
-                    input_name, require_valid_partitions=False
-                )
-                input_keys = list(subset.get_partition_keys())
-
-                # This check represents a temporary constraint that prevents huge query results for upstream
-                # partition data versions from timing out runs. If a partitioned dependency (a) uses an
-                # AllPartitionMapping; and (b) has greater than or equal to
-                # SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD dependency partitions, then we
-                # process it as a non-partitioned dependency (note that this was the behavior for
-                # all partition dependencies prior to 2023-08). This means that stale status
-                # results cannot be accurately computed for the dependency, and there is thus
-                # corresponding logic in the CachingStaleStatusResolver to account for this. This
-                # constraint should be removed when we have thoroughly examined the performance of
-                # the data version retrieval query and can guarantee decent performance.
-                if len(input_keys) < SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD:
-                    data_version = self._get_partitions_data_version_from_keys(key, input_keys)
-                else:
-                    data_version = extract_data_version_from_entry(event.event_log_entry)
-            else:
-                data_version = extract_data_version_from_entry(event.event_log_entry)
-            self.input_asset_version_info[key] = InputAssetVersionInfo(
-                storage_id,
-                check.not_none(event.event_log_entry.dagster_event).event_type,
-                data_version,
-                event.run_id,
-                event.timestamp,
-            )
+        asset_records_by_key = self._fetch_asset_records(asset_keys)
+        for key in asset_keys:
+            asset_record = asset_records_by_key.get(key)
+            event = self._get_input_asset_event(key, asset_record)
+            if event is None:
+                self.input_asset_version_info[key] = None
+            else:
+                storage_id = event.storage_id
+                # Input name will be none if this is an internal dep
+                input_name = self._context.job_def.asset_layer.input_for_asset_key(
+                    self._context.node_handle, key
+                )
+                # Exclude AllPartitionMapping for now to avoid huge queries
+                if input_name and self._context.has_asset_partitions_for_input(input_name):
+                    subset = self._context.asset_partitions_subset_for_input(
+                        input_name, require_valid_partitions=False
+                    )
+                    input_keys = list(subset.get_partition_keys())
+
+                    # This check represents a temporary constraint that prevents huge query results for upstream
+                    # partition data versions from timing out runs. If a partitioned dependency (a) uses an
+                    # AllPartitionMapping; and (b) has greater than or equal to
+                    # SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD dependency partitions, then we
+                    # process it as a non-partitioned dependency (note that this was the behavior for
+                    # all partition dependencies prior to 2023-08). This means that stale status
+                    # results cannot be accurately computed for the dependency, and there is thus
+                    # corresponding logic in the CachingStaleStatusResolver to account for this. This
+                    # constraint should be removed when we have thoroughly examined the performance of
+                    # the data version retrieval query and can guarantee decent performance.
+                    if len(input_keys) < SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD:
+                        data_version = self._get_partitions_data_version_from_keys(key, input_keys)
+                    else:
+                        data_version = extract_data_version_from_entry(event.event_log_entry)
+                else:
+                    data_version = extract_data_version_from_entry(event.event_log_entry)
+                self.input_asset_version_info[key] = InputAssetVersionInfo(
+                    storage_id,
+                    check.not_none(event.event_log_entry.dagster_event).event_type,
+                    data_version,
+                    event.run_id,
+                    event.timestamp,
+                )
+
+    def _fetch_asset_records(self, asset_keys: Sequence[AssetKey]) -> Dict[AssetKey, "AssetRecord"]:
+        batch_size = int(os.getenv("GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE", "100"))
+        asset_records_by_key = {}
+        to_fetch = asset_keys
+        while len(to_fetch):
+            for record in self._context.instance.get_asset_records(to_fetch[:batch_size]):
+                asset_records_by_key[record.asset_entry.asset_key] = record
+            to_fetch = to_fetch[batch_size:]
+
+        return asset_records_by_key
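An aside on the chunking loop above: the same pattern, restated as a standalone sketch with a generic fetch_records callable standing in for DagsterInstance.get_asset_records (the helper name and string keys are illustrative, not Dagster APIs; the env var and its default of 100 come from the diff). Note that keys with no stored record never appear in the result, which is why the caller uses asset_records_by_key.get(key) rather than indexing.

import os
from typing import Callable, Dict, List, Sequence, Tuple

Record = Tuple[str, str]  # illustrative stand-in for an AssetRecord: (key, payload)

def fetch_in_batches(
    keys: Sequence[str],
    fetch_records: Callable[[Sequence[str]], List[Record]],
) -> Dict[str, Record]:
    # Same env var and default as _fetch_asset_records above.
    batch_size = int(os.getenv("GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE", "100"))
    records_by_key: Dict[str, Record] = {}
    to_fetch = list(keys)
    while to_fetch:
        # One storage round trip per chunk instead of one per key.
        for record in fetch_records(to_fetch[:batch_size]):
            records_by_key[record[0]] = record
        to_fetch = to_fetch[batch_size:]
    return records_by_key

# Records exist for "a" but not "b", so "b" is simply absent from the result.
result = fetch_in_batches(["a", "b"], lambda ks: [(k, f"record-{k}") for k in ks if k != "b"])
assert result.get("b") is None and result["a"] == ("a", "record-a")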

-    def _get_input_asset_event(self, key: AssetKey) -> Optional["EventLogRecord"]:
-        event = self._context.instance.get_latest_data_version_record(key)
+    def _get_input_asset_event(
+        self, key: AssetKey, asset_record: Optional["AssetRecord"]
+    ) -> Optional["EventLogRecord"]:
+        event = None
+        if asset_record and asset_record.asset_entry.last_materialization_record:
+            event = asset_record.asset_entry.last_materialization_record
+        elif (
+            asset_record
+            and self._context.instance.event_log_storage.asset_records_have_last_observation
+        ):
+            event = asset_record.asset_entry.last_observation_record
+
+        if (
+            not event
+            and not self._context.instance.event_log_storage.asset_records_have_last_observation
+        ):
+            event = next(
+                iter(self._context.instance.fetch_observations(key, limit=1).records), None
+            )
         if event:
             self._check_input_asset_event(key, event)
         return event
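The replacement resolves the input event in a fixed order: prefer the materialization cached on the batched asset record; otherwise use the cached observation when the event log storage keeps one (asset_records_have_last_observation); only query fetch_observations directly when it does not. A minimal sketch of that precedence, with plain strings standing in for event records (all names below are illustrative, not Dagster APIs):

from typing import Callable, Optional

def resolve_input_event(
    last_materialization: Optional[str],
    last_observation: Optional[str],
    storage_caches_observations: bool,
    fetch_latest_observation: Callable[[], Optional[str]],
) -> Optional[str]:
    # 1) Prefer the materialization cached on the batched asset record.
    if last_materialization:
        return last_materialization
    # 2) If the storage caches the last observation on asset records,
    #    trust the cached value (which may itself be None).
    if storage_caches_observations:
        return last_observation
    # 3) Otherwise fall back to one targeted query for the latest observation.
    return fetch_latest_observation()

# Storage without cached observations falls back to the direct query:
assert resolve_input_event(None, None, False, lambda: "obs-1") == "obs-1"
# A cached materialization always wins:
assert resolve_input_event("mat-1", "obs-1", True, lambda: "obs-2") == "mat-1"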
Second changed file (data version tests):

@@ -5,7 +5,9 @@

 import pytest
 from dagster import (
+    AssetIn,
     AssetMaterialization,
+    AssetOut,
     DagsterInstance,
     MaterializeResult,
     RunConfig,

@@ -16,8 +18,6 @@
 )
 from dagster._config.field import Field
 from dagster._config.pythonic_config import Config
-from dagster._core.definitions.asset_in import AssetIn
-from dagster._core.definitions.asset_out import AssetOut
 from dagster._core.definitions.data_version import (
     DATA_VERSION_TAG,
     SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD,

@@ -44,6 +44,7 @@
     ASSET_PARTITION_RANGE_END_TAG,
     ASSET_PARTITION_RANGE_START_TAG,
 )
+from dagster._utils import Counter, traced_counter
 from dagster._utils.test.data_versions import (
     assert_code_version,
     assert_data_version,

@@ -1175,3 +1176,33 @@ def asset1():
     assert extract_data_provenance_from_entry(record.event_log_entry).input_storage_ids == {
         AssetKey(["asset0"]): 500
     }
+
+
+def test_fan_in():
+    def create_upstream_asset(i: int):
+        @asset(name=f"upstream_asset_{i}", code_version="abc")
+        def upstream_asset():
+            return i
+
+        return upstream_asset
+
+    upstream_assets = [create_upstream_asset(i) for i in range(100)]
+
+    @asset(
+        ins={f"input_{i}": AssetIn(key=f"upstream_asset_{i}") for i in range(100)},
+        code_version="abc",
+    )
+    def downstream_asset(**kwargs):
+        return kwargs.values()
+
+    all_assets = [*upstream_assets, downstream_asset]
+    instance = DagsterInstance.ephemeral()
+    materialize_assets(all_assets, instance)
+
+    counter = Counter()
+    traced_counter.set(counter)
+    materialize_assets(all_assets, instance)[downstream_asset.key]
+    assert traced_counter.get().counts() == {
+        "DagsterInstance.get_asset_records": 1,
+        "DagsterInstance.get_run_record_by_id": 1,
+    }
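With the default batch size of 100, all hundred upstream keys fit in one chunk, so the trace counter records exactly one DagsterInstance.get_asset_records call for the second materialization. A hypothetical tweak using the env var this PR introduces would split the fetch into several calls (the expected count follows from the chunking loop in _fetch_asset_records, not from anything asserted in the test):

import os

# Hypothetical: chunks of 25 would turn the same fetch into
# ceil(100 / 25) = 4 get_asset_records calls.
os.environ["GET_ASSET_RECORDS_FOR_DATA_VERSION_BATCH_SIZE"] = "25"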