Skip to content

Commit

Permalink
[dagster-dbt] add config schema for dbt asset op (#11710)
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 17, 2023
1 parent c96fef1 commit 627f67c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
23 changes: 23 additions & 0 deletions python_modules/libraries/dagster-dbt/dagster_dbt/asset_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
get_dagster_logger,
op,
)
from dagster._config.field import Field
from dagster._config.field_utils import Permissive
from dagster._core.definitions.events import CoercibleToAssetKeyPrefix
from dagster._core.definitions.load_assets_from_modules import prefix_assets
from dagster._core.definitions.metadata import RawMetadataValue
from dagster._core.errors import DagsterInvalidSubsetError
from dagster._legacy import OpExecutionContext
from dagster._utils.backcompat import experimental_arg_warning
from dagster._utils.merger import deep_merge_dicts

from dagster_dbt.cli.types import DbtCliOutput
from dagster_dbt.cli.utils import execute_cli
Expand Down Expand Up @@ -365,6 +368,23 @@ def _get_dbt_op(
ins=ins,
out=outs,
required_resource_keys={dbt_resource_key},
config_schema=Field(
Permissive(
{
"select": Field(str, is_required=False),
"exclude": Field(str, is_required=False),
"vars": Field(dict, is_required=False),
"full_refresh": Field(bool, is_required=False),
}
),
default_value={},
description=(
"Keyword arguments to pass to the underlying dbt command. Additional arguments not"
" listed in the schema will be passed through as well, e.g. {'bool_flag': True,"
" 'string_flag': 'hi'} will result in the flags '--bool-flag --string-flag hi'"
" being passed into the underlying execution"
),
),
)
def _dbt_op(context):
dbt_output = None
Expand All @@ -391,6 +411,9 @@ def _dbt_op(context):
if partition_key_to_vars_fn:
kwargs["vars"] = partition_key_to_vars_fn(context.partition_key)

# merge in any additional kwargs from the config
kwargs = deep_merge_dicts(kwargs, context.op_config)

if use_build_command:
dbt_output = dbt_resource.build(**kwargs)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from dagster import (
AssetIn,
AssetKey,
DailyPartitionsDefinition,
FreshnessPolicy,
IOManager,
MetadataEntry,
ResourceDefinition,
asset,
io_manager,
materialize_to_memory,
repository,
)
from dagster._core.definitions import build_assets_job
Expand Down Expand Up @@ -268,8 +270,6 @@ def test_custom_freshness_policy():
def test_partitions(
dbt_seed, conn_string, test_project_dir, dbt_config_dir
): # pylint: disable=unused-argument
from dagster import DailyPartitionsDefinition, materialize_to_memory

def _partition_key_to_vars(partition_key: str):
if partition_key == "2022-01-02":
return {"fail_test": True}
Expand Down Expand Up @@ -550,6 +550,50 @@ def hanger2():
assert all_keys == expected_keys


@pytest.mark.parametrize(
"config,expected_asset_names",
[
({"exclude": "tag:not_a_tag"}, "ALL"),
(
{"select": "sort_by_calories"},
"sort_by_calories",
),
({"full_refresh": True}, "ALL"),
({"vars": {"my_var": "my_value", "another_var": 3, "a_third_var": True}}, "ALL"),
],
)
def test_op_config(
config, expected_asset_names, dbt_seed, conn_string, test_project_dir, dbt_config_dir
):
if expected_asset_names == "ALL":
expected_asset_names = (
"sort_by_calories,cold_schema/sort_cold_cereals_by_calories,"
"sort_hot_cereals_by_calories,subdir_schema/least_caloric"
)
manifest_path = file_relative_path(__file__, "sample_manifest.json")
with open(manifest_path, "r", encoding="utf8") as f:
manifest_json = json.load(f)

dbt_assets = load_assets_from_dbt_manifest(manifest_json)
result = materialize_to_memory(
assets=dbt_assets,
run_config={"ops": {"run_dbt_5ad73": {"config": config}}},
resources={
"dbt": dbt_cli_resource.configured(
{"project_dir": test_project_dir, "profiles_dir": dbt_config_dir}
)
},
)
assert result.success
all_keys = {
event.event_specific_data.materialization.asset_key
for event in result.all_events
if event.event_type_value == "ASSET_MATERIALIZATION"
}
expected_keys = {AssetKey(name.split("/")) for name in expected_asset_names.split(",")}
assert all_keys == expected_keys


@pytest.mark.parametrize("load_from_manifest", [True, False])
@pytest.mark.parametrize(
"select,exclude,expected_asset_names",
Expand Down

0 comments on commit 627f67c

Please sign in to comment.