Skip to content

Commit

Permalink
[dagster-dbt] allow for static manifest.json-based selection (#8047)
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed May 25, 2022
1 parent 220e74d commit d7129a5
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 18 deletions.
76 changes: 61 additions & 15 deletions python_modules/libraries/dagster-dbt/dagster_dbt/asset_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dagster import _check as check
from dagster import get_dagster_logger, op
from dagster.core.definitions.metadata import RawMetadataValue
from dagster.core.errors import DagsterInvalidSubsetError


def _load_manifest_for_project(
Expand Down Expand Up @@ -52,6 +53,50 @@ def _load_manifest_for_project(
return json.load(f), cli_output


def _select_unique_ids_from_manifest_json(
    manifest_json: Mapping[str, Any], select: str
) -> AbstractSet[str]:
    """Resolve a dbt selection string against an already-loaded manifest.json.

    The selection is evaluated statically using dbt's own graph selector, built
    from the manifest's ``child_map``, ``nodes`` and ``sources`` entries, rather
    than by invoking the dbt CLI.

    Args:
        manifest_json: Parsed contents of a dbt ``manifest.json``. Must contain
            the ``child_map``, ``nodes`` and ``sources`` keys.
        select: A dbt selection string (e.g. ``"tag:foo"`` or ``"+my_model"``).

    Returns:
        The set of unique ids that dbt's ``NodeSelector`` selects for ``select``.

    Raises:
        DagsterInvalidSubsetError: If the selection matches nothing.

    A dagster check failure is triggered if ``dbt`` or ``networkx`` cannot be
    imported. Errors raised by dbt while parsing or resolving the selection are
    not caught here and propagate to the caller.
    """
    # Imported lazily so that dbt-core / networkx are only required when the
    # `select` argument is actually used.
    try:
        import dbt.graph.cli as graph_cli
        import dbt.graph.selector as graph_selector
        from dbt.contracts.graph.manifest import Manifest
        from networkx import DiGraph
    except ImportError:
        check.failed(
            "In order to use the `select` argument on load_assets_from_dbt_manifest, you must "
            "have `dbt-core >= 1.0.0` and `networkx` installed."
        )

    class _DictShim(dict):
        """Shim to enable hydrating a dictionary into a dot-accessible object."""

        def __getattr__(self, item):
            # Missing keys resolve to None instead of raising AttributeError.
            ret = super().get(item)
            # allow recursive access e.g. foo.bar.baz
            return _DictShim(ret) if isinstance(ret, dict) else ret

    # generate a dbt-compatible graph from the existing child map
    graph = graph_selector.Graph(DiGraph(incoming_graph_data=manifest_json["child_map"]))
    manifest = Manifest(
        # dbt expects dataclasses that can be accessed with dot notation, not bare dictionaries
        nodes={unique_id: _DictShim(info) for unique_id, info in manifest_json["nodes"].items()},
        sources={
            unique_id: _DictShim(info) for unique_id, info in manifest_json["sources"].items()
        },
    )

    # create a parsed selection from the select string
    parsed_spec = graph_cli.parse_union([select], True)

    # execute this selection against the graph
    selector = graph_selector.NodeSelector(graph, manifest)
    selected, _ = selector.select_nodes(parsed_spec)
    if not selected:
        raise DagsterInvalidSubsetError(f"No dbt models match the selection string '{select}'.")
    return selected


def _get_node_name(node_info: Mapping[str, Any]):
return "__".join([node_info["resource_type"], node_info["package_name"], node_info["name"]])

Expand Down Expand Up @@ -97,12 +142,13 @@ def _dbt_nodes_to_assets(
for unique_id in selected_unique_ids:
cur_asset_deps = set()
node_info = dbt_nodes[unique_id]
if node_info["resource_type"] != "model":
continue
package_name = node_info.get("package_name", package_name)

for dep_name in node_info["depends_on"]["nodes"]:
dep_type = dbt_nodes[dep_name]["resource_type"]

# ignore seeds/snapshots
# ignore seeds/snapshots/tests
if dep_type not in ["source", "model"]:
continue
dep_asset_key = node_info_to_asset_key(dbt_nodes[dep_name])
Expand Down Expand Up @@ -299,6 +345,7 @@ def load_assets_from_dbt_manifest(
] = None,
io_manager_key: Optional[str] = None,
selected_unique_ids: Optional[AbstractSet[str]] = None,
select: Optional[str] = None,
node_info_to_asset_key: Callable[[Mapping[str, Any]], AssetKey] = _get_node_asset_key,
use_build_command: bool = False,
) -> Sequence[AssetsDefinition]:
Expand Down Expand Up @@ -328,20 +375,19 @@ def load_assets_from_dbt_manifest(
check.dict_param(manifest_json, "manifest_json", key_type=str)
dbt_nodes = {**manifest_json["nodes"], **manifest_json["sources"]}

def _unique_id_to_selector(uid):
# take the fully-qualified node name and use it to select the model
return ".".join(dbt_nodes[uid]["fqn"])
if select is None:
if selected_unique_ids:
# generate selection string from unique ids
select = " ".join(".".join(dbt_nodes[uid]["fqn"]) for uid in selected_unique_ids)
else:
# if no selection specified, default to "*"
select = "*"
selected_unique_ids = manifest_json["nodes"].keys()

if selected_unique_ids is None:
# must resolve the selection string using the existing manifest.json data (hacky)
selected_unique_ids = _select_unique_ids_from_manifest_json(manifest_json, select)

select = (
"*"
if selected_unique_ids is None
else " ".join(_unique_id_to_selector(uid) for uid in selected_unique_ids)
)
selected_unique_ids = selected_unique_ids or set(
unique_id
for unique_id, node_info in dbt_nodes.items()
if node_info["resource_type"] == "model"
)
return [
_dbt_nodes_to_assets(
dbt_nodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ def prepare_dbt_cli(conn_string): # pylint: disable=unused-argument, redefined-
yield


@pytest.fixture(scope="class")
@pytest.fixture(scope="session")
def dbt_seed(
prepare_dbt_cli, dbt_executable, dbt_config_dir
): # pylint: disable=unused-argument, redefined-outer-name
subprocess.run([dbt_executable, "seed", "--profiles-dir", dbt_config_dir], check=True)


@pytest.fixture(scope="class")
@pytest.fixture(scope="session")
def dbt_build(
prepare_dbt_cli, dbt_executable, dbt_config_dir
): # pylint: disable=unused-argument, redefined-outer-name
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{{ config(tags=["foo", "bar"]) }}
SELECT *
from "test-schema".cereals
ORDER BY calories
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{{ config(tags=["foo"]) }}
SELECT *
FROM {{ ref('sort_by_calories') }}
WHERE type='C'
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{{ config(tags=["bar"]) }}
SELECT *
FROM {{ ref('sort_by_calories') }}
WHERE type='H'
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
{{ config(tags=["bar"]) }}
SELECT * from {{ ref('sort_by_calories') }} LIMIT 1

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,101 @@ def hanger2(least_caloric):
}
expected_keys = {AssetKey(["test-schema", name]) for name in expected_asset_names.split(",")}
assert all_keys == expected_keys


@pytest.mark.parametrize("load_from_manifest", [True, False])
@pytest.mark.parametrize(
    "select,expected_asset_names",
    [
        (
            "*",
            {
                "sort_by_calories",
                "sort_cold_cereals_by_calories",
                "least_caloric",
                "sort_hot_cereals_by_calories",
            },
        ),
        (
            "+least_caloric",
            {"sort_by_calories", "least_caloric"},
        ),
        (
            "sort_by_calories least_caloric",
            {"sort_by_calories", "least_caloric"},
        ),
        (
            "tag:bar+",
            {
                "sort_by_calories",
                "sort_cold_cereals_by_calories",
                "least_caloric",
                "sort_hot_cereals_by_calories",
            },
        ),
        (
            "tag:foo",
            {"sort_by_calories", "sort_cold_cereals_by_calories"},
        ),
        (
            "tag:foo,tag:bar",
            {"sort_by_calories"},
        ),
    ],
)
def test_dbt_selects(
    dbt_build,
    conn_string,
    test_project_dir,
    dbt_config_dir,
    load_from_manifest,
    select,
    expected_asset_names,
):  # pylint: disable=unused-argument
    """Check that a dbt selection string yields the expected assets.

    Runs once loading assets from the sample manifest and once loading them
    from the dbt project, then verifies both the asset keys on the first
    definition and the keys actually materialized by an in-process job run.
    """
    if not load_from_manifest:
        dbt_assets = load_assets_from_dbt_project(
            project_dir=test_project_dir, profiles_dir=dbt_config_dir, select=select
        )
    else:
        sample_manifest = file_relative_path(__file__, "sample_manifest.json")
        with open(sample_manifest, "r", encoding="utf8") as manifest_file:
            loaded_manifest = json.load(manifest_file)

        dbt_assets = load_assets_from_dbt_manifest(loaded_manifest, select=select)

    expected_asset_keys = {AssetKey(["test-schema", key]) for key in expected_asset_names}
    assert dbt_assets[0].asset_keys == expected_asset_keys

    # Run every selected asset in a single job backed by the dbt CLI resource.
    dbt_resource = dbt_cli_resource.configured(
        {"project_dir": test_project_dir, "profiles_dir": dbt_config_dir}
    )
    asset_group = AssetGroup(dbt_assets, resource_defs={"dbt": dbt_resource})
    dbt_job = asset_group.build_job(name="dbt_job")
    result = dbt_job.execute_in_process()

    assert result.success

    # Collect the asset key of every materialization event emitted by the run.
    materialized_keys = set()
    for event in result.all_events:
        if event.event_type_value == "ASSET_MATERIALIZATION":
            materialized_keys.add(event.event_specific_data.materialization.asset_key)
    assert materialized_keys == expected_asset_keys


@pytest.mark.parametrize(
    "select,error_match",
    [
        ("tag:nonexist", "No dbt models match"),
        ("asjdlhalskujh:z", "not a valid method name"),
    ],
)
def test_static_select_invalid_selection(select, error_match):
    """Loading assets from the sample manifest with a bad selection raises a matching error."""
    sample_manifest = file_relative_path(__file__, "sample_manifest.json")
    with open(sample_manifest, "r", encoding="utf8") as manifest_file:
        loaded_manifest = json.load(manifest_file)

    with pytest.raises(Exception, match=error_match):
        load_assets_from_dbt_manifest(loaded_manifest, select=select)

0 comments on commit d7129a5

Please sign in to comment.