RFC: Fetch Cross-Repo Asset Deps (#7259)
clairelin135 committed Apr 13, 2022
1 parent 90ac1b5 commit a5c2854
Showing 7 changed files with 290 additions and 30 deletions.
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, Dict, Mapping

from dagster_graphql.implementation.loader import CrossRepoAssetDependedByLoader

from dagster import AssetKey, DagsterEventType, EventRecordsFilter, check, seven
from dagster.core.events import ASSET_EVENTS

@@ -64,14 +66,19 @@ def get_asset_nodes_by_asset_key(graphene_info) -> Mapping[AssetKey, "GrapheneAs

from ..schema.asset_graph import GrapheneAssetNode

depended_by_loader = CrossRepoAssetDependedByLoader(context=graphene_info.context)

asset_nodes_by_asset_key: Dict[AssetKey, GrapheneAssetNode] = {}
for location in graphene_info.context.repository_locations:
for repository in location.get_repositories().values():
for external_asset_node in repository.get_external_asset_nodes():
preexisting_node = asset_nodes_by_asset_key.get(external_asset_node.asset_key)
if preexisting_node is None or preexisting_node.external_asset_node.op_name is None:
asset_nodes_by_asset_key[external_asset_node.asset_key] = GrapheneAssetNode(
location, repository, external_asset_node
location,
repository,
external_asset_node,
depended_by_loader=depended_by_loader,
)

return asset_nodes_by_asset_key
@@ -1,14 +1,21 @@
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from dagster import DagsterInstance, check
from dagster.core.definitions.events import AssetKey
from dagster.core.events.log import EventLogEntry
from dagster.core.host_representation import ExternalRepository
from dagster.core.host_representation.external_data import (
ExternalAssetDependedBy,
ExternalAssetDependency,
ExternalAssetNode,
)
from dagster.core.scheduler.instigation import InstigatorType
from dagster.core.storage.pipeline_run import JobBucket, RunRecord, RunsFilter, TagBucket
from dagster.core.storage.tags import SCHEDULE_NAME_TAG, SENSOR_NAME_TAG
from dagster.core.workspace.context import WorkspaceRequestContext


class RepositoryDataType(Enum):
@@ -285,3 +292,114 @@ def get_latest_materialization_for_asset_key(self, asset_key: AssetKey) -> Event
def _fetch(self):
self._fetched = True
self._materializations = self._instance.get_latest_materialization_events(self._asset_keys)


class CrossRepoAssetDependedByLoader:
"""
A batch loader that computes cross-repository asset dependencies. Locates source assets
within all workspace repositories, and determines if they are derived (defined) assets in
other repositories.
For each asset that contains cross-repo dependencies (every asset that is defined as a source
asset in another repository) a sink asset is any asset immediately downstream of the source
asset.
E.g. Asset A is defined in repo X and referenced in repo Y as source asset C (but contains the
same asset key as A). If within repo C has a downstream asset B, B is a sink asset of A (it
is external from A's repo but an edge exists from A to B).
The @lru_cache decorator enables the _build_cross_repo_deps method to cache its return value
to avoid recalculating the asset dependencies on repeated calls to the method.
"""

def __init__(self, context: WorkspaceRequestContext):
self._context = context

@lru_cache(maxsize=1)
def _build_cross_repo_deps(
self,
) -> Tuple[
Dict[AssetKey, ExternalAssetNode],
Dict[Tuple[str, str], Dict[AssetKey, List[ExternalAssetDependedBy]]],
]:
"""
This method constructs a sink asset as an ExternalAssetNode for every asset immediately
downstream of a source asset that is defined in another repository as a derived asset.
In Dagit, sink assets will display as ForeignAssets, which are external from the repository.
This method also stores a mapping from source asset key to ExternalAssetDependedBy nodes
that depend on the asset with that key. When get_cross_repo_dependent_assets is called with a derived
asset's asset key and its location, all dependent ExternalAssetDependedBy nodes are returned.
"""
depended_by_assets_by_source_asset: Dict[AssetKey, List[ExternalAssetDependedBy]] = {}

map_defined_asset_to_location: Dict[
AssetKey, Tuple[str, str]
] = {} # key is asset key, value is tuple (location_name, repo_name)

external_asset_node_by_asset_key: Dict[
AssetKey, ExternalAssetNode
] = {} # only contains derived assets
for location in self._context.repository_locations:
repositories = location.get_repositories()
for repo_name, external_repo in repositories.items():
asset_nodes = external_repo.get_external_asset_nodes()
for asset_node in asset_nodes:
if not asset_node.op_name: # is source asset
if asset_node.asset_key not in depended_by_assets_by_source_asset:
depended_by_assets_by_source_asset[asset_node.asset_key] = []
depended_by_assets_by_source_asset[asset_node.asset_key].extend(
asset_node.depended_by
)
else:
map_defined_asset_to_location[asset_node.asset_key] = (
location.name,
repo_name,
)
external_asset_node_by_asset_key[asset_node.asset_key] = asset_node

sink_assets: Dict[AssetKey, ExternalAssetNode] = {}
# Nested dict that maps depended-by assets by asset key, by location tuple
# (repo_location.name, repo_name)
external_asset_deps: Dict[
Tuple[str, str], Dict[AssetKey, List[ExternalAssetDependedBy]]
] = {}

for source_asset, depended_by_assets in depended_by_assets_by_source_asset.items():
asset_def_location = map_defined_asset_to_location.get(source_asset, None)
if asset_def_location: # source asset is defined as asset in another repository
if asset_def_location not in external_asset_deps:
external_asset_deps[asset_def_location] = {}
if source_asset not in external_asset_deps[asset_def_location]:
external_asset_deps[asset_def_location][source_asset] = []
external_asset_deps[asset_def_location][source_asset].extend(depended_by_assets)
for asset in depended_by_assets:
# SourceAssets defined as ExternalAssetNodes contain no definition data (e.g.
# no output or partition definition data) and no job_names. Dagit displays
# all ExternalAssetNodes with no job_names as foreign assets, so sink assets
# are defined as ExternalAssetNodes with no definition data.
sink_assets[asset.downstream_asset_key] = ExternalAssetNode(
asset_key=asset.downstream_asset_key,
dependencies=[
ExternalAssetDependency(
upstream_asset_key=source_asset,
input_name=asset.input_name,
output_name=asset.output_name,
)
],
depended_by=[],
)
return sink_assets, external_asset_deps

def get_sink_asset(self, asset_key: AssetKey) -> Optional[ExternalAssetNode]:
sink_assets, _ = self._build_cross_repo_deps()
return sink_assets.get(asset_key)

def get_cross_repo_dependent_assets(
self, repository_location_name: str, repository_name: str, asset_key: AssetKey
) -> List[ExternalAssetDependedBy]:
_, external_asset_deps = self._build_cross_repo_deps()
return external_asset_deps.get((repository_location_name, repository_name), {}).get(
asset_key, []
)
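
The docstrings above describe a two-pass computation: first, every repository in the workspace is scanned and its external asset nodes are split into source assets (no op_name) and derived assets; second, each source asset that is also defined as a derived asset elsewhere contributes its immediate downstream assets as sink assets of that definition. The sketch below is a minimal, self-contained illustration of that shape — it uses plain strings, tuples, and dicts in place of ExternalAssetNode and ExternalAssetDependedBy, and a made-up workspace mapping, so it approximates the approach rather than reproducing the actual implementation.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# Hypothetical, simplified stand-ins for the real external-data types: an asset
# node is (asset_key, op_name, depended_by), where op_name is None for source
# assets and depended_by lists the keys of immediately downstream assets.
AssetNode = Tuple[str, Optional[str], List[str]]
Workspace = Dict[Tuple[str, str], List[AssetNode]]  # keyed by (location, repo)


def build_cross_repo_deps(workspace: Workspace):
    """Sketch of the two-pass computation described in the docstrings above."""
    # Pass 1: split nodes into source assets (no op_name) and derived assets,
    # remembering where each derived asset is defined.
    depended_by_source: Dict[str, List[str]] = defaultdict(list)
    defined_location: Dict[str, Tuple[str, str]] = {}
    for (location, repo), nodes in workspace.items():
        for asset_key, op_name, depended_by in nodes:
            if op_name is None:  # source asset
                depended_by_source[asset_key].extend(depended_by)
            else:  # derived (defined) asset
                defined_location[asset_key] = (location, repo)

    # Pass 2: every source asset that is defined as a derived asset somewhere
    # else contributes its downstream assets as sink assets of that definition,
    # keyed by the defining (location, repo).
    sink_assets: Dict[str, str] = {}  # sink asset key -> upstream source key
    external_deps: Dict[Tuple[str, str], Dict[str, List[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    for source_key, downstream_keys in depended_by_source.items():
        location_and_repo = defined_location.get(source_key)
        if location_and_repo is None:
            continue  # never defined as a derived asset in any repository
        external_deps[location_and_repo][source_key].extend(downstream_keys)
        for downstream_key in downstream_keys:
            sink_assets[downstream_key] = source_key
    return sink_assets, external_deps


# Mirrors the repo X / repo Y example from the class docstring: A is defined in
# repo_x; repo_y declares it as a source asset with a downstream asset B.
workspace: Workspace = {
    ("loc_x", "repo_x"): [("A", "a_op", [])],
    ("loc_y", "repo_y"): [("A", None, ["B"]), ("B", "b_op", [])],
}
sinks, deps = build_cross_repo_deps(workspace)
assert sinks == {"B": "A"}
assert deps[("loc_x", "repo_x")]["A"] == ["B"]
```

The real method additionally records input/output names on the synthesized dependency and wraps each sink asset in an ExternalAssetNode with no definition data and no job_names, which is what makes Dagit render it as a foreign asset.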
@@ -17,7 +17,7 @@
ExternalTimeWindowPartitionsDefinitionData,
)

from ..implementation.loader import BatchMaterializationLoader
from ..implementation.loader import BatchMaterializationLoader, CrossRepoAssetDependedByLoader
from . import external
from .asset_key import GrapheneAssetKey
from .errors import GrapheneAssetNotFoundError
@@ -43,6 +43,7 @@ def __init__(
input_name: str,
asset_key: AssetKey,
materialization_loader: Optional[BatchMaterializationLoader] = None,
depended_by_loader: Optional[CrossRepoAssetDependedByLoader] = None,
):
self._repository_location = check.inst_param(
repository_location, "repository_location", RepositoryLocation
@@ -54,12 +55,17 @@
self._latest_materialization_loader = check.opt_inst_param(
materialization_loader, "materialization_loader", BatchMaterializationLoader
)
self._depended_by_loader = check.opt_inst_param(
depended_by_loader, "depended_by_loader", CrossRepoAssetDependedByLoader
)
super().__init__(inputName=input_name)

def resolve_asset(self, _graphene_info):
asset_node = check.not_none(
self._external_repository.get_external_asset_node(self._asset_key)
)
asset_node = self._external_repository.get_external_asset_node(self._asset_key)
if not asset_node and self._depended_by_loader:
# Only load from dependency loader if asset node cannot be found in current repository
asset_node = self._depended_by_loader.get_sink_asset(self._asset_key)
asset_node = check.not_none(asset_node)
return GrapheneAssetNode(
self._repository_location,
self._external_repository,
@@ -108,6 +114,7 @@ def __init__(
external_repository: ExternalRepository,
external_asset_node: ExternalAssetNode,
materialization_loader: Optional[BatchMaterializationLoader] = None,
depended_by_loader: Optional[CrossRepoAssetDependedByLoader] = None,
):
self._repository_location = check.inst_param(
repository_location,
@@ -123,6 +130,9 @@
self._latest_materialization_loader = check.opt_inst_param(
materialization_loader, "materialization_loader", BatchMaterializationLoader
)
self._depended_by_loader = check.opt_inst_param(
depended_by_loader, "depended_by_loader", CrossRepoAssetDependedByLoader
)

super().__init__(
id=external_asset_node.asset_key.to_string(),
@@ -203,12 +213,27 @@ def resolve_computeKind(self, _graphene_info) -> Optional[str]:
return self._external_asset_node.compute_kind

def resolve_dependedBy(self, graphene_info) -> List[GrapheneAssetDependency]:
if not self._external_asset_node.depended_by:
# CrossRepoAssetDependedByLoader class loads cross-repo asset dependencies workspace-wide.
# In order to avoid recomputing workspace-wide values per asset node, we add a loader
# that batch loads all cross-repo dependencies for the whole workspace.
check.invariant(
self._depended_by_loader,
"depended_by_loader must exist in order to resolve dependedBy nodes",
)

depended_by_asset_nodes = self._depended_by_loader.get_cross_repo_dependent_assets(
self._repository_location.name,
self._external_repository.name,
self._external_asset_node.asset_key,
)
depended_by_asset_nodes.extend(self._external_asset_node.depended_by)

if not depended_by_asset_nodes:
return []

materialization_loader = BatchMaterializationLoader(
instance=graphene_info.context.instance,
asset_keys=[dep.downstream_asset_key for dep in self._external_asset_node.depended_by],
asset_keys=[dep.downstream_asset_key for dep in depended_by_asset_nodes],
)

return [
@@ -218,14 +243,29 @@ def resolve_dependedBy(self, graphene_info) -> List[GrapheneAssetDependency]:
input_name=dep.input_name,
asset_key=dep.downstream_asset_key,
materialization_loader=materialization_loader,
depended_by_loader=self._depended_by_loader,
)
for dep in self._external_asset_node.depended_by
for dep in depended_by_asset_nodes
]

def resolve_dependedByKeys(self, _graphene_info) -> List[GrapheneAssetKey]:
# CrossRepoAssetDependedByLoader class loads all cross-repo asset dependencies workspace-wide.
# In order to avoid recomputing workspace-wide values per asset node, we add a loader
# that batch loads all cross-repo dependencies for the whole workspace.
check.invariant(
self._depended_by_loader,
"depended_by_loader must exist in order to resolve dependedBy nodes",
)

depended_by_asset_nodes = self._depended_by_loader.get_cross_repo_dependent_assets(
self._repository_location.name,
self._external_repository.name,
self._external_asset_node.asset_key,
)
depended_by_asset_nodes.extend(self._external_asset_node.depended_by)

return [
GrapheneAssetKey(path=dep.downstream_asset_key.path)
for dep in self._external_asset_node.depended_by
GrapheneAssetKey(path=dep.downstream_asset_key.path) for dep in depended_by_asset_nodes
]

def resolve_dependencyKeys(self, _graphene_info):
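
The comments in resolve_dependedBy and resolve_dependedByKeys note that cross-repo dependencies are computed workspace-wide, so a single loader instance is shared across all asset nodes in a request, and @lru_cache(maxsize=1) on _build_cross_repo_deps keeps that scan from running more than once per loader. A toy sketch of that memoization pattern (with a counter standing in for the expensive workspace traversal):

```python
from functools import lru_cache


class FakeDependedByLoader:
    """Toy stand-in that demonstrates the caching behavior, not the real loader."""

    def __init__(self):
        self.scan_count = 0  # how many times the "workspace scan" actually ran

    @lru_cache(maxsize=1)
    def _build_cross_repo_deps(self):
        # The real method walks every repository in the workspace; here we just
        # record that the expensive work happened and return empty results.
        self.scan_count += 1
        return {}, {}  # (sink_assets, external_asset_deps)

    def get_sink_asset(self, asset_key):
        sink_assets, _ = self._build_cross_repo_deps()
        return sink_assets.get(asset_key)

    def get_cross_repo_dependent_assets(self, location_name, repo_name, asset_key):
        _, external_asset_deps = self._build_cross_repo_deps()
        return external_asset_deps.get((location_name, repo_name), {}).get(asset_key, [])


loader = FakeDependedByLoader()
loader.get_sink_asset("some_asset")
loader.get_cross_repo_dependent_assets("some_location", "some_repo", "some_asset")
assert loader.scan_count == 1  # both resolvers reuse the single cached result
```

Because the cache key includes self and maxsize is 1, a fresh loader constructed per request (as in get_asset_nodes_by_asset_key and resolve_assetNodes elsewhere in this change) recomputes the dependencies at most once per request.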
@@ -39,7 +39,7 @@
)
from ...implementation.fetch_sensors import get_sensor_or_error, get_sensors_or_error
from ...implementation.fetch_solids import get_graph_or_error
from ...implementation.loader import BatchMaterializationLoader
from ...implementation.loader import BatchMaterializationLoader, CrossRepoAssetDependedByLoader
from ...implementation.run_config_schema import resolve_run_config_schema_or_error
from ...implementation.utils import graph_selector_from_graphql, pipeline_selector_from_graphql
from ..asset_graph import GrapheneAssetNode, GrapheneAssetNodeOrError
@@ -471,12 +471,15 @@ def resolve_assetNodes(self, graphene_info, **kwargs):
materialization_loader = BatchMaterializationLoader(
instance=graphene_info.context.instance, asset_keys=[node.assetKey for node in results]
)

depended_by_loader = CrossRepoAssetDependedByLoader(context=graphene_info.context)
return [
GrapheneAssetNode(
node.repository_location,
node.external_repository,
node.external_asset_node,
materialization_loader=materialization_loader,
depended_by_loader=depended_by_loader,
)
for node in results
]
@@ -0,0 +1,42 @@
# pylint: disable=redefined-outer-name
from dagster import AssetGroup, AssetKey, SourceAsset, asset, repository


@asset
def upstream_asset():
return 5


upstream_asset_group = AssetGroup([upstream_asset])


@repository
def upstream_assets_repository():
return [upstream_asset_group]


source_assets = [SourceAsset(AssetKey("upstream_asset"))]


@asset
def downstream_asset1(upstream_asset):
assert upstream_asset


@asset
def downstream_asset2(upstream_asset):
assert upstream_asset


downstream_asset_group1 = AssetGroup(assets=[downstream_asset1], source_assets=source_assets)
downstream_asset_group2 = AssetGroup(assets=[downstream_asset2], source_assets=source_assets)


@repository
def downstream_assets_repository1():
return [downstream_asset_group1]


@repository
def downstream_assets_repository2():
return [downstream_asset_group2]

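These repositories exercise the scenario from the loader's docstring: upstream_asset is defined in upstream_assets_repository and declared as a SourceAsset in both downstream repositories, so downstream_asset1 and downstream_asset2 become sink assets of upstream_asset. Below is a rough sketch of what a test against the loader might assert; it assumes a CrossRepoAssetDependedByLoader already built from a workspace context containing all three repositories (constructing that context is not shown), and the helper name and location argument are placeholders.

```python
from dagster import AssetKey


def check_expected_cross_repo_edges(loader, upstream_location_name):
    # `loader` is assumed to be a CrossRepoAssetDependedByLoader built from a
    # workspace context that includes the three repositories defined above.
    upstream_key = AssetKey("upstream_asset")

    depended_by = loader.get_cross_repo_dependent_assets(
        upstream_location_name, "upstream_assets_repository", upstream_key
    )
    assert {dep.downstream_asset_key for dep in depended_by} == {
        AssetKey("downstream_asset1"),
        AssetKey("downstream_asset2"),
    }

    # Each downstream asset should also resolve as a sink asset: an
    # ExternalAssetNode whose only dependency is upstream_asset.
    sink = loader.get_sink_asset(AssetKey("downstream_asset1"))
    assert sink is not None
    assert [dep.upstream_asset_key for dep in sink.dependencies] == [upstream_key]
```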