Skip to content

Commit

Permalink
asset defs directly on repository (#8197)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed Jun 8, 2022
1 parent a1a31d4 commit eea9a77
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, name: Optional[str] = None, description: Optional[str] = None
self.description = check.opt_str_param(description, "description")

def __call__(self, fn: Callable[[], Any]) -> RepositoryDefinition:
from dagster.core.asset_defs import AssetGroup
from dagster.core.asset_defs import AssetGroup, AssetsDefinition, SourceAsset

check.callable_param(fn, "fn")

Expand All @@ -43,6 +43,8 @@ def __call__(self, fn: Callable[[], Any]) -> RepositoryDefinition:
or isinstance(definition, SensorDefinition)
or isinstance(definition, GraphDefinition)
or isinstance(definition, AssetGroup)
or isinstance(definition, AssetsDefinition)
or isinstance(definition, SourceAsset)
):
bad_definitions.append((i, type(definition)))
if bad_definitions:
Expand All @@ -55,7 +57,8 @@ def __call__(self, fn: Callable[[], Any]) -> RepositoryDefinition:
raise DagsterInvalidDefinitionError(
"Bad return value from repository construction function: all elements of list "
"must be of type JobDefinition, GraphDefinition, PipelineDefinition, "
"PartitionSetDefinition, ScheduleDefinition, or SensorDefinition. "
"PartitionSetDefinition, ScheduleDefinition, SensorDefinition, "
"AssetsDefinition, or SourceAsset."
f"Got {bad_definitions_str}."
)
repository_data = CachingRepositoryData.from_list(repository_definitions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def __init__(
],
schedules: Mapping[str, Union[ScheduleDefinition, Resolvable[ScheduleDefinition]]],
sensors: Mapping[str, Union[SensorDefinition, Resolvable[SensorDefinition]]],
source_assets: Mapping[AssetKey, SourceAsset],
source_assets_by_key: Mapping[AssetKey, SourceAsset],
):
"""Constructs a new CachingRepositoryData object.
Expand All @@ -458,7 +458,7 @@ def __init__(
The schedules belonging to the repository.
sensors (Mapping[str, Union[SensorDefinition, Callable[[], SensorDefinition]]]):
The sensors belonging to a repository.
source_assets (Mapping[AssetKey, SourceAsset]): The source assets belonging to a repository.
source_assets_by_key (Mapping[AssetKey, SourceAsset]): The source assets belonging to a repository.
"""
check.mapping_param(
pipelines, "pipelines", key_type=str, value_type=(PipelineDefinition, FunctionType)
Expand All @@ -477,7 +477,7 @@ def __init__(
sensors, "sensors", key_type=str, value_type=(SensorDefinition, FunctionType)
)
check.mapping_param(
source_assets, "source_assets", key_type=AssetKey, value_type=SourceAsset
source_assets_by_key, "source_assets_by_key", key_type=AssetKey, value_type=SourceAsset
)

self._pipelines = _CacheingDefinitionIndex(
Expand Down Expand Up @@ -508,7 +508,7 @@ def __init__(
for schedule in self._schedules.get_all_definitions()
if isinstance(schedule, PartitionScheduleDefinition)
]
self._source_assets = source_assets
self._source_assets_by_key = source_assets_by_key

def load_partition_sets_from_pipelines() -> List[PartitionSetDefinition]:
job_partition_sets = []
Expand Down Expand Up @@ -608,7 +608,7 @@ def from_dict(repository_definitions: Dict[str, Dict[str, Any]]) -> "CachingRepo
f"Object mapped to {key} is not an instance of JobDefinition or GraphDefinition."
)

return CachingRepositoryData(**repository_definitions, source_assets={})
return CachingRepositoryData(**repository_definitions, source_assets_by_key={})

@classmethod
def from_list(
Expand All @@ -631,14 +631,15 @@ def from_list(
Use this constructor when you have no need to lazy load pipelines/jobs or other
definitions.
"""
from dagster.core.asset_defs import AssetGroup
from dagster.core.asset_defs import AssetGroup, AssetsDefinition

pipelines_or_jobs: Dict[str, Union[PipelineDefinition, JobDefinition]] = {}
coerced_graphs: Dict[str, JobDefinition] = {}
partition_sets: Dict[str, PartitionSetDefinition] = {}
schedules: Dict[str, ScheduleDefinition] = {}
sensors: Dict[str, SensorDefinition] = {}
source_assets: Dict[AssetKey, SourceAsset] = {}
assets_defs: List[AssetsDefinition] = []
source_assets: List[SourceAsset] = []
combined_asset_group = None
for definition in repository_definitions:
if isinstance(definition, PipelineDefinition):
Expand Down Expand Up @@ -697,23 +698,35 @@ def from_list(
)
pipelines_or_jobs[coerced.name] = coerced
coerced_graphs[coerced.name] = coerced

elif isinstance(definition, AssetGroup):
if combined_asset_group:
combined_asset_group += definition
else:
combined_asset_group = definition
elif isinstance(definition, AssetsDefinition):
assets_defs.append(definition)
elif isinstance(definition, SourceAsset):
source_assets.append(definition)
else:
check.failed(f"Unexpected repository entry {definition}")

if assets_defs or source_assets:
if combined_asset_group is not None:
raise DagsterInvalidDefinitionError(
"A repository can't have both an AssetGroup and direct asset defs"
)
combined_asset_group = AssetGroup(assets=assets_defs, source_assets=source_assets)

if combined_asset_group:
for job_def in combined_asset_group.get_base_jobs():
pipelines_or_jobs[job_def.name] = job_def

source_assets = {
source_assets_by_key = {
source_asset.key: source_asset
for source_asset in combined_asset_group.source_assets
}
else:
source_assets_by_key = {}

for name, sensor_def in sensors.items():
if sensor_def.has_loadable_targets():
Expand Down Expand Up @@ -744,7 +757,7 @@ def from_list(
partition_sets=partition_sets,
schedules=schedules,
sensors=sensors,
source_assets=source_assets,
source_assets_by_key=source_assets_by_key,
)

def get_pipeline_names(self) -> List[str]:
Expand Down Expand Up @@ -965,7 +978,7 @@ def has_sensor(self, sensor_name: str) -> bool:
return self._sensors.has_definition(sensor_name)

def get_source_assets_by_key(self) -> Mapping[AssetKey, SourceAsset]:
return self._source_assets
return self._source_assets_by_key

def _check_solid_defs(self, pipelines: List[PipelineDefinition]) -> None:
solid_defs = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,28 @@ def my_repo():
}


def test_direct_assets():
foo = SourceAsset("foo")

@asset
def asset1():
...

@asset
def asset2():
...

@repository
def my_repo():
return [foo, asset1, asset2]

assert len(my_repo.get_all_jobs()) == 1
assert set(my_repo.get_all_jobs()[0].asset_layer.asset_keys) == {
AssetKey(["asset1"]),
AssetKey(["asset2"]),
}


def _create_graph_with_name(name):
@graph(name=name)
def _the_graph():
Expand Down

0 comments on commit eea9a77

Please sign in to comment.