Skip to content

Commit

Permalink
autodiscover assets at module scope (#7247)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed Apr 12, 2022
1 parent 6d5eae5 commit 7472843
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 44 deletions.
48 changes: 24 additions & 24 deletions python_modules/dagster/dagster/core/code_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,9 @@ def __new__(cls, python_file: str, fn_name: str, working_directory: Optional[str

def load_target(self) -> object:
module = load_python_file(self.python_file, self.working_directory)
if not hasattr(module, self.fn_name):
raise DagsterInvariantViolationError(
"{name} not found at module scope in file {file}.".format(
name=self.fn_name, file=self.python_file
)
)

return getattr(module, self.fn_name)
return _load_target_from_module(
module, self.fn_name, f"at module scope in file {self.python_file}."
)

def describe(self) -> str:
if self.working_directory:
Expand All @@ -192,6 +187,21 @@ def describe(self) -> str:
return "{self.python_file}::{self.fn_name}".format(self=self)


def _load_target_from_module(module: ModuleType, fn_name: str, error_suffix: str) -> object:
from dagster.core.asset_defs import AssetGroup
from dagster.core.workspace.autodiscovery import LOAD_ALL_ASSETS

if fn_name == LOAD_ALL_ASSETS:
# LOAD_ALL_ASSETS is a special symbol that's returned when, instead of loading a particular
# attribute, we should load all the assets in the module.
return AssetGroup.from_modules([module])
else:
if not hasattr(module, fn_name):
raise DagsterInvariantViolationError(f"{fn_name} not found {error_suffix}")

return getattr(module, fn_name)


@whitelist_for_serdes
class ModuleCodePointer(
NamedTuple(
Expand All @@ -210,14 +220,9 @@ def __new__(cls, module: str, fn_name: str, working_directory: Optional[str] = N

def load_target(self) -> object:
module = load_python_module(self.module, self.working_directory)

if not hasattr(module, self.fn_name):
raise DagsterInvariantViolationError(
"{name} not found in module {module}. dir: {dir}".format(
name=self.fn_name, module=self.module, dir=dir(module)
)
)
return getattr(module, self.fn_name)
return _load_target_from_module(
module, self.fn_name, f"in module {self.module}. dir: {dir(module)}"
)

def describe(self) -> str:
return "from {self.module} import {self.fn_name}".format(self=self)
Expand All @@ -241,14 +246,9 @@ def __new__(cls, module: str, attribute: str, working_directory: Optional[str] =

def load_target(self) -> object:
module = load_python_module(self.module, self.working_directory)

if not hasattr(module, self.attribute):
raise DagsterInvariantViolationError(
"{name} not found in module {module}. dir: {dir}".format(
name=self.attribute, module=self.module, dir=dir(module)
)
)
return getattr(module, self.attribute)
return _load_target_from_module(
module, self.attribute, f"in module {self.module}. dir: {dir(module)}"
)

def describe(self) -> str:
return "from {self.module} import {self.attribute}".format(self=self)
Expand Down
30 changes: 19 additions & 11 deletions python_modules/dagster/dagster/core/workspace/autodiscovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from dagster.core.asset_defs import AssetGroup
from dagster.core.code_pointer import load_python_file, load_python_module

LOAD_ALL_ASSETS = "<<LOAD_ALL_ASSETS>>"


class LoadableTarget(NamedTuple):
attribute: str
Expand Down Expand Up @@ -83,10 +85,10 @@ def loadable_targets_from_loaded_module(module: ModuleType) -> Sequence[Loadable
elif len(loadable_graphs) > 1:
raise DagsterInvariantViolationError(
(
'No repository, job, or pipeline, and more than one graph found in "{module_name}". '
"If you load a file or module directly it must either have one repository, one "
"job, one pipeline, or one graph in scope. Found graphs defined in variables or "
"decorated functions: {graph_symbols}."
'More than one graph found in "{module_name}". '
"If you load a file or module directly and it has no repositories, jobs, or "
"pipelines in scope, it must have no more than one graph in scope. "
"Found graphs defined in variables or decorated functions: {graph_symbols}."
).format(
module_name=module.__name__,
graph_symbols=repr([g.attribute for g in loadable_graphs]),
Expand All @@ -101,17 +103,23 @@ def loadable_targets_from_loaded_module(module: ModuleType) -> Sequence[Loadable
var_names = repr([a.attribute for a in loadable_asset_groups])
raise DagsterInvariantViolationError(
(
f'More than one asset collection found in "{module.__name__}". '
"If you load a file or module directly it must either have one repository, one "
"job, one pipeline, one graph, or one asset collection scope. Found asset "
f"collections defined in variables: {var_names}."
f'More than one asset group found in "{module.__name__}". '
"If you load a file or module directly and it has no repositories, jobs, "
"pipeline, or graphs in scope, it must have no more than one asset group in scope. "
f"Found asset groups defined in variables: {var_names}."
)
)

asset_group_from_module_assets = AssetGroup.from_modules([module])
if (
len(asset_group_from_module_assets.assets) > 0
or len(asset_group_from_module_assets.source_assets) > 0
):
return [LoadableTarget(LOAD_ALL_ASSETS, asset_group_from_module_assets)]

raise DagsterInvariantViolationError(
'No jobs, pipelines, graphs, asset collections, or repositories found in "{}".'.format(
module.__name__
)
"No repositories, jobs, pipelines, graphs, asset groups, or asset definitions found in "
f'"{module.__name__}".'
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# pylint: disable=redefined-outer-name
from dagster import AssetKey, SourceAsset, asset

source_asset = SourceAsset(AssetKey("source_asset"))


@asset
def asset1(source_asset):
assert source_asset


@asset
def asset2():
pass
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dagster.core.definitions.reconstruct import repository_def_from_pointer
from dagster.core.errors import DagsterImportError
from dagster.core.workspace.autodiscovery import (
LOAD_ALL_ASSETS,
loadable_targets_from_python_file,
loadable_targets_from_python_module,
loadable_targets_from_python_package,
Expand Down Expand Up @@ -87,10 +88,10 @@ def test_double_graph():
loadable_targets_from_python_file(double_pipeline_path)

assert str(exc_info.value) == (
'No repository, job, or pipeline, and more than one graph found in "double_graph". '
"If you load a file or module directly it must either have one repository, "
"one job, one pipeline, or one graph in scope. Found graphs defined in variables or decorated "
"functions: ['graph_one', 'graph_two']."
'More than one graph found in "double_graph". '
"If you load a file or module directly and it has no repositories, jobs, or "
"pipelines in scope, it must have no more than one graph in scope. "
"Found graphs defined in variables or decorated functions: ['graph_one', 'graph_two']."
)


Expand All @@ -115,20 +116,35 @@ def test_double_asset_group():
loadable_targets_from_python_file(path)

assert str(exc_info.value) == (
'More than one asset collection found in "double_asset_group". '
"If you load a file or module directly it must either have one repository, one "
"job, one pipeline, one graph, or one asset collection scope. Found asset "
"collections defined in variables: ['ac1', 'ac2']."
'More than one asset group found in "double_asset_group". '
"If you load a file or module directly and it has no repositories, jobs, "
"pipeline, or graphs in scope, it must have no more than one asset group in scope. "
"Found asset groups defined in variables: ['ac1', 'ac2']."
)


def test_multiple_assets():
path = file_relative_path(__file__, "multiple_assets.py")
loadable_targets = loadable_targets_from_python_file(path)

assert len(loadable_targets) == 1
symbol = loadable_targets[0].attribute
assert symbol == LOAD_ALL_ASSETS

repo_def = repository_def_from_pointer(CodePointer.from_python_file(path, symbol, None))

isinstance(repo_def, RepositoryDefinition)
the_job = repo_def.get_job("__ASSET_GROUP")
assert len(the_job.graph.node_defs) == 2


def test_no_loadable_targets():
with pytest.raises(DagsterInvariantViolationError) as exc_info:
loadable_targets_from_python_file(file_relative_path(__file__, "nada.py"))

assert (
str(exc_info.value)
== 'No jobs, pipelines, graphs, asset collections, or repositories found in "nada".'
== 'No repositories, jobs, pipelines, graphs, asset groups, or asset definitions found in "nada".'
)


Expand Down

0 comments on commit 7472843

Please sign in to comment.