Skip to content

Commit

Permalink
Merge pull request #1526 from fishtown-analytics/fix/refs-in-run-oper…
Browse files Browse the repository at this point in the history
…ations

Add ref and source support to run-operation macros (#1507)
  • Loading branch information
beckjake committed Jun 13, 2019
2 parents 16519b1 + 9a40395 commit 24adb74
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 100 deletions.
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/impl.py
Expand Up @@ -880,8 +880,8 @@ def execute_macro(self, macro_name, manifest=None, project=None,
)
# This causes a reference cycle, as dbt.context.runtime.generate()
# ends up calling get_adapter, so the import has to be here.
import dbt.context.runtime
macro_context = dbt.context.runtime.generate_macro(
import dbt.context.operation
macro_context = dbt.context.operation.generate(
macro,
self.config,
manifest
Expand Down
16 changes: 16 additions & 0 deletions core/dbt/context/common.py
Expand Up @@ -68,6 +68,22 @@ def commit(self):
return self.adapter.commit_if_has_connection()


class BaseResolver(object):
def __init__(self, db_wrapper, model, config, manifest):
self.db_wrapper = db_wrapper
self.model = model
self.config = config
self.manifest = manifest

@property
def current_project(self):
return self.config.project_name

@property
def Relation(self):
return self.db_wrapper.Relation


def _add_macro_map(context, package_name, macro_map):
"""Update an existing context in-place, adding the given macro map to the
appropriate package namespace. Adapter packages get inserted into the
Expand Down
29 changes: 29 additions & 0 deletions core/dbt/context/operation.py
@@ -0,0 +1,29 @@
import dbt.context.common
from dbt.context import runtime
from dbt.exceptions import raise_compiler_error


class RefResolver(runtime.BaseRefResolver):
def __call__(self, *args):
# When you call ref(), this is what happens at operation runtime
target_model, name = self.resolve(args)
return self.create_relation(target_model, name)

def create_ephemeral_relation(self, target_model, name):
# In operations, we can't ref() ephemeral nodes, because ParsedMacros
# do not support set_cte
raise_compiler_error(
'Operations can not ref() ephemeral nodes, but {} is ephemeral'
.format(target_model.name),
self.model
)


class Provider(runtime.Provider):
ref = RefResolver


def generate(model, runtime_config, manifest):
return dbt.context.common.generate_execute_macro(
model, runtime_config, manifest, Provider()
)
60 changes: 33 additions & 27 deletions core/dbt/context/parser.py
Expand Up @@ -4,23 +4,6 @@
from dbt.adapters.factory import get_adapter


execute = False


def ref(db_wrapper, model, config, manifest):

def ref(*args):
if len(args) == 1 or len(args) == 2:
model.refs.append(list(args))

else:
dbt.exceptions.ref_invalid_args(model, args)

return db_wrapper.adapter.Relation.create_from_node(config, model)

return ref


def docs(unparsed, docrefs, column_name=None):

def do_docs(*args):
Expand All @@ -46,14 +29,6 @@ def do_docs(*args):
return do_docs


def source(db_wrapper, model, config, manifest):
def do_source(source_name, table_name):
model.sources.append([source_name, table_name])
return db_wrapper.adapter.Relation.create_from_node(config, model)

return do_source


class Config(object):
def __init__(self, model, source_config):
self.model = model
Expand Down Expand Up @@ -123,18 +98,49 @@ def get_missing_var(self, var_name):
return None


class RefResolver(dbt.context.common.BaseResolver):
def __call__(self, *args):
# When you call ref(), this is what happens at parse time
if len(args) == 1 or len(args) == 2:
self.model.refs.append(list(args))

else:
dbt.exceptions.ref_invalid_args(self.model, args)

return self.Relation.create_from_node(self.config, self.model)


class SourceResolver(dbt.context.common.BaseResolver):
def __call__(self, source_name, table_name):
# When you call source(), this is what happens at parse time
self.model.sources.append([source_name, table_name])
return self.Relation.create_from_node(self.config, self.model)


class Provider(object):
execute = False
Config = Config
DatabaseWrapper = DatabaseWrapper
Var = Var
ref = RefResolver
source = SourceResolver


def generate(model, runtime_config, manifest, source_config):
# during parsing, we don't have a connection, but we might need one, so we
# have to acquire it.
# In the future, it would be nice to lazily open the connection, as in some
# projects it would be possible to parse without connecting to the db
with get_adapter(runtime_config).connection_named(model.get('name')):
return dbt.context.common.generate(
model, runtime_config, manifest, source_config, dbt.context.parser
model, runtime_config, manifest, source_config, Provider()
)


def generate_macro(model, runtime_config, manifest):
# parser.generate_macro is called by the get_${attr}_func family of Parser
# methods, which preparse and cache the generate_${attr}_name family of
# macros for use during parsing
return dbt.context.common.generate_execute_macro(
model, runtime_config, manifest, dbt.context.parser
model, runtime_config, manifest, Provider()
)
116 changes: 59 additions & 57 deletions core/dbt/context/runtime.py
Expand Up @@ -8,80 +8,79 @@
from dbt.logger import GLOBAL_LOGGER as logger # noqa


execute = True


def ref(db_wrapper, model, config, manifest):
current_project = config.project_name
adapter = db_wrapper.adapter

def do_ref(*args):
target_model_name = None
target_model_package = None
class BaseRefResolver(dbt.context.common.BaseResolver):
def resolve(self, args):
name = None
package = None

if len(args) == 1:
target_model_name = args[0]
name = args[0]
elif len(args) == 2:
target_model_package, target_model_name = args
package, name = args
else:
dbt.exceptions.ref_invalid_args(model, args)
dbt.exceptions.ref_invalid_args(self.model, args)

target_model = ParserUtils.resolve_ref(
manifest,
target_model_name,
target_model_package,
current_project,
model.get('package_name'))
self.manifest,
name,
package,
self.current_project,
self.model.package_name)

if target_model is None or target_model is ParserUtils.DISABLED:
dbt.exceptions.ref_target_not_found(
model,
target_model_name,
target_model_package)

target_model_id = target_model.get('unique_id')

if target_model_id not in model.get('depends_on', {}).get('nodes'):
dbt.exceptions.ref_bad_context(model,
target_model_name,
target_model_package)

is_ephemeral = (get_materialization(target_model) == 'ephemeral')

if is_ephemeral:
model.set_cte(target_model_id, None)
return adapter.Relation.create(
type=adapter.Relation.CTE,
identifier=add_ephemeral_model_prefix(
target_model_name)).quote(identifier=False)
self.model,
name,
package)
return target_model, name

def create_ephemeral_relation(self, target_model, name):
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create(
type=self.Relation.CTE,
identifier=add_ephemeral_model_prefix(name)
).quote(identifier=False)

def create_relation(self, target_model, name):
if get_materialization(target_model) == 'ephemeral':
return self.create_ephemeral_relation(target_model, name)
else:
return adapter.Relation.create_from_node(config, target_model)
return self.Relation.create_from_node(self.config, target_model)

return do_ref

class RefResolver(BaseRefResolver):
def validate(self, resolved, args):
if resolved.unique_id not in self.model.depends_on.get('nodes'):
dbt.exceptions.ref_bad_context(self.model, args)

def source(db_wrapper, model, config, manifest):
current_project = config.project_name
def __call__(self, *args):
# When you call ref(), this is what happens at runtime
target_model, name = self.resolve(args)
self.validate(target_model, args)
return self.create_relation(target_model, name)

def do_source(source_name, table_name):

class SourceResolver(dbt.context.common.BaseResolver):
def resolve(self, source_name, table_name):
target_source = ParserUtils.resolve_source(
manifest,
self.manifest,
source_name,
table_name,
current_project,
model.get('package_name')
self.current_project,
self.model.package_name
)

if target_source is None:
dbt.exceptions.source_target_not_found(
model,
self.model,
source_name,
table_name)
return target_source

model.sources.append([source_name, table_name])
return db_wrapper.Relation.create_from_source(target_source)

return do_source
def __call__(self, source_name, table_name):
"""When you call source(), this is what happens at runtime"""
target_source = self.resolve(source_name, table_name)
return self.Relation.create_from_source(target_source)


class Config:
Expand Down Expand Up @@ -137,12 +136,15 @@ class Var(dbt.context.common.Var):
pass


def generate(model, runtime_config, manifest):
return dbt.context.common.generate(
model, runtime_config, manifest, None, dbt.context.runtime)
class Provider(object):
execute = True
Config = Config
DatabaseWrapper = DatabaseWrapper
Var = Var
ref = RefResolver
source = SourceResolver


def generate_macro(model, runtime_config, manifest):
return dbt.context.common.generate_execute_macro(
model, runtime_config, manifest, dbt.context.runtime
)
def generate(model, runtime_config, manifest):
return dbt.context.common.generate(
model, runtime_config, manifest, None, Provider())
9 changes: 3 additions & 6 deletions core/dbt/exceptions.py
Expand Up @@ -314,12 +314,9 @@ def ref_invalid_args(model, args):
model)


def ref_bad_context(model, target_model_name, target_model_package):
ref_string = "{{ ref('" + target_model_name + "') }}"

if target_model_package is not None:
ref_string = ("{{ ref('" + target_model_package +
"', '" + target_model_name + "') }}")
def ref_bad_context(model, args):
ref_args = ', '.join("'{}'".format(a) for a in args)
ref_string = '{{{{ ref({}) }}}}'.format(ref_args)

base_error_msg = """dbt was unable to infer all dependencies for the model "{model_name}".
This typically happens when ref() is placed within a conditional block.
Expand Down
3 changes: 1 addition & 2 deletions core/dbt/main.py
Expand Up @@ -625,8 +625,7 @@ def _build_run_operation_subparser(subparsers, base_subparser):
of dbt. Please use it with caution"""
)
sub.add_argument(
'--macro',
required=True,
'macro',
help="""
Specify the macro to invoke. dbt will call this macro with the
supplied arguments and then exit"""
Expand Down
11 changes: 9 additions & 2 deletions test/integration/042_sources_test/macros/macro.sql
@@ -1,7 +1,14 @@
{% macro override_me() -%}
{{ exceptions.raise_compiler_error('this is a bad macro') }}
{{ exceptions.raise_compiler_error('this is a bad macro') }}
{%- endmacro %}

{% macro happy_little_macro() -%}
{{ override_me() }}
{{ override_me() }}
{%- endmacro %}


{% macro vacuum_source(source_name, table_name) -%}
{% call statement('stmt', auto_begin=false, fetch_result=false) %}
vacuum {{ source(source_name, table_name) }}
{% endcall %}
{%- endmacro %}
17 changes: 16 additions & 1 deletion test/integration/042_sources_test/test_sources.py
Expand Up @@ -49,6 +49,14 @@ def run_dbt_with_vars(self, cmd, *args, **kwargs):


class TestSources(BaseSourcesTest):
@property
def project_config(self):
cfg = super(TestSources, self).project_config
cfg.update({
'macro-paths': ['macros'],
})
return cfg

def _create_schemas(self):
super(TestSources, self)._create_schemas()
self._create_schema_named(self.default_database,
Expand Down Expand Up @@ -138,7 +146,7 @@ def test_postgres_source_only_def(self):
self.assertTableDoesNotExist('nonsource_descendant')

@use_profile('postgres')
def test_source_childrens_parents(self):
def test_postgres_source_childrens_parents(self):
results = self.run_dbt_with_vars([
'run', '--models', '@source:test_source'
])
Expand All @@ -149,6 +157,13 @@ def test_source_childrens_parents(self):
)
self.assertTableDoesNotExist('nonsource_descendant')

@use_profile('postgres')
def test_postgres_run_operation_source(self):
kwargs = '{"source_name": "test_source", "table_name": "test_table"}'
self.run_dbt_with_vars([
'run-operation', 'vacuum_source', '--args', kwargs
])


class TestSourceFreshness(BaseSourcesTest):
def setUp(self):
Expand Down
Expand Up @@ -22,3 +22,10 @@
vacuum "{{ schema }}"."{{ table_name }}"
{% endcall %}
{% endmacro %}


{% macro vacuum_ref(ref_target) %}
{% call statement('stmt', auto_begin=false) %}
vacuum {{ ref(ref_target) }}
{% endcall %}
{% endmacro %}

0 comments on commit 24adb74

Please sign in to comment.