Add initial unit tests for ModelRunner class (#10196)
* Add unit test for `ModelRunner.print_result_line`

* Add (and skip) unit test for `ModelRunner.execute`

An attempt at testing `ModelRunner.execute`. We should probably also be
asserting that the model has been executed. However, before we can get there,
we're running into runtime errors during `ModelRunner.execute`. Currently the
struggle is ensuring the adapter exists in the global factory when `execute`
goes looking for it. The error we're getting looks like the following:
```
    def test_execute(self, table_model: ModelNode, manifest: Manifest, model_runner: ModelRunner) -> None:
>       model_runner.execute(model=table_model, manifest=manifest)

tests/unit/task/test_run.py:121:
----
core/dbt/task/run.py:259: in execute
    context = generate_runtime_model_context(model, self.config, manifest)
core/dbt/context/providers.py:1636: in generate_runtime_model_context
    ctx = ModelContext(model, config, manifest, RuntimeProvider(), None)
core/dbt/context/providers.py:834: in __init__
    self.adapter = get_adapter(self.config)
venv/lib/python3.10/site-packages/dbt/adapters/factory.py:207: in get_adapter
    return FACTORY.lookup_adapter(config.credentials.type)
----
self = <dbt.adapters.factory.AdapterContainer object at 0x106e73280>, adapter_name = 'postgres'

    def lookup_adapter(self, adapter_name: str) -> Adapter:
>       return self.adapters[adapter_name]
E       KeyError: 'postgres'

venv/lib/python3.10/site-packages/dbt/adapters/factory.py:132: KeyError
```
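
To make the failure mode concrete, here is a toy sketch of a registry that fails the same way as the lookup in the traceback. `ToyAdapterContainer`, `TOY_FACTORY`, and `lookup_fails_before_registration` are hypothetical names, not dbt's classes; the only thing taken from the traceback is that the lookup is a plain dictionary access keyed by the adapter type, so anything that was never registered raises `KeyError`.

```python
from typing import Dict


class ToyAdapterContainer:
    """Hypothetical stand-in for a global adapter registry (not dbt's class)."""

    def __init__(self) -> None:
        self.adapters: Dict[str, object] = {}

    def register(self, name: str, adapter: object) -> None:
        # Store the adapter under its type name, e.g. "postgres".
        self.adapters[name] = adapter

    def lookup(self, name: str) -> object:
        # Plain dict access: raises KeyError if nothing was registered.
        return self.adapters[name]


# Module-level singleton, mirroring the idea of a global factory.
TOY_FACTORY = ToyAdapterContainer()


def lookup_fails_before_registration() -> bool:
    """Return True if looking up 'postgres' currently raises KeyError."""
    try:
        TOY_FACTORY.lookup("postgres")
    except KeyError:
        return True
    return False
```

In this sketch, an adapter object that is only handed to the runner, but never registered in the global container, stays invisible to any code path that goes through the container.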

* Add `postgres_adapter` fixture for use in `TestModelRunner`

Previously we were running into an issue where, during `ModelRunner.execute`,
the `mock_adapter` we were using wouldn't be found in the global adapter
factory. We've gotten past this error by supplying a "real" adapter, a
`PostgresAdapter` instance. However, we're now running into a new error
in which the materialization macro can't be found. This error looks like the following:
```
model_runner = <dbt.task.run.ModelRunner object at 0x106746650>

    def test_execute(
        self, table_model: ModelNode, manifest: Manifest, model_runner: ModelRunner
    ) -> None:
>       model_runner.execute(model=table_model, manifest=manifest)

tests/unit/task/test_run.py:129:
----
self = <dbt.task.run.ModelRunner object at 0x106746650>
model = ModelNode(database='dbt', schema='dbt_schema', name='table_model', resource_type=<NodeType.Model: 'model'>, package_na...ected'>, constraints=[], version=None, latest_version=None, deprecation_date=None, defer_relation=None, primary_key=[])
manifest = Manifest(nodes={'seed.pkg.seed': SeedNode(database='dbt', schema='dbt_schema', name='seed', resource_type=<NodeType.Se...s(show=True, node_color=None), patch_path=None, arguments=[], created_at=1718229810.21914, supported_languages=None)}})

    def execute(self, model, manifest):
        context = generate_runtime_model_context(model, self.config, manifest)

        materialization_macro = manifest.find_materialization_macro_by_name(
            self.config.project_name, model.get_materialization(), self.adapter.type()
        )

        if materialization_macro is None:
>           raise MissingMaterializationError(
                materialization=model.get_materialization(), adapter_type=self.adapter.type()
            )
E           dbt.adapters.exceptions.compilation.MissingMaterializationError: Compilation Error
E             No materialization 'table' was found for adapter postgres! (searched types 'default' and 'postgres')

core/dbt/task/run.py:266: MissingMaterializationError
```
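
As a rough illustration of the lookup that fails here, the sketch below searches a dictionary of macros for a materialization. Everything in it is an assumption made for illustration: `find_toy_materialization` is a hypothetical helper, the `materialization_<name>_<adapter type>` naming is inferred from the macro name `materialization_table_default`, and the search order (adapter-specific first, then `default`) is a guess based on the "searched types" part of the error message. The real lookup in dbt also deals with packages and precedence, which this ignores.

```python
from typing import Dict, Optional


def find_toy_materialization(
    macros: Dict[str, str], materialization: str, adapter_type: str
) -> Optional[str]:
    """Return the first matching macro body, or None if nothing matches."""
    # Assumed search order: adapter-specific first, then the 'default' fallback.
    for candidate_type in (adapter_type, "default"):
        macro_name = f"materialization_{materialization}_{candidate_type}"
        if macro_name in macros:
            return macros[macro_name]
    return None


# A project with no materialization macros at all (the failing situation).
empty_project: Dict[str, str] = {}

# A project with a spoofed default 'table' materialization.
spoofed_project = {"materialization_table_default": "SELECT 1"}
```

With `empty_project` the helper returns `None`, which corresponds to the branch that raises `MissingMaterializationError` in the traceback; with `spoofed_project` the default entry is found even though the adapter type is `postgres`.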

* Add spoofed macro fixture `materialization_table_default` for `test_execute` test

Previously the `TestModelRunner:test_execute` test was running into a runtime error
due to the `materialization_table_default` macro not existing in the project. This
commit adds that macro to the project (though it should ideally get loaded via interactions
between the manifest and adapter). Manually adding it resolved our previous issue, but created
a new one. The macro appears not to be properly loaded into the manifest, and thus isn't
discoverable later on when getting the macros for the jinja context. This leads to an error
that looks like the following:
```
model_runner = <dbt.task.run.ModelRunner object at 0x1080a4f70>

    def test_execute(
        self, table_model: ModelNode, manifest: Manifest, model_runner: ModelRunner
    ) -> None:
>       model_runner.execute(model=table_model, manifest=manifest)

tests/unit/task/test_run.py:129:
----
core/dbt/task/run.py:287: in execute
    result = MacroGenerator(
core/dbt/clients/jinja.py:82: in __call__
    return self.call_macro(*args, **kwargs)
venv/lib/python3.10/site-packages/dbt_common/clients/jinja.py:294: in call_macro
    macro = self.get_macro()
---
self = <dbt.clients.jinja.MacroGenerator object at 0x1080f3130>

    def get_macro(self):
        name = self.get_name()
        template = self.get_template()
        # make the module. previously we set both vars and local, but that's
        # redundant: They both end up in the same place
        # make_module is in jinja2.environment. It returns a TemplateModule
        module = template.make_module(vars=self.context, shared=False)
>       macro = module.__dict__[get_dbt_macro_name(name)]
E       KeyError: 'dbt_macro__materialization_table_default'

venv/lib/python3.10/site-packages/dbt_common/clients/jinja.py:277: KeyError
```

It's becoming apparent that we need to find a better way to either mock or legitimately
load the default and adapter macros. At this point I think I've exhausted the time box
I should be using to figure out whether testing the `ModelRunner` class is currently possible,
with the result being that more work remains to be done.

* Begin adding the `LogModelResult` event catcher to event manager class fixture
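
For readers unfamiliar with the event-catcher pattern, here is a minimal, self-contained sketch of the idea. `ToyEvent`, `ToyEventCatcher`, and `ToyEventManager` are hypothetical names; this is not the `EventCatcher` from `tests.utils` nor dbt's event manager, only an approximation of how a callback registered on a manager can collect events of one type so a test can assert on them.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Type


@dataclass
class ToyEvent:
    """Hypothetical event type used only for this sketch."""

    status: str


@dataclass
class ToyEventCatcher:
    """Collects events of a single type (assumed behaviour, for illustration)."""

    event_to_catch: Type[Any]
    caught_events: List[Any] = field(default_factory=list)

    def catch(self, event: Any) -> None:
        # Keep only events of the type this catcher was configured for.
        if isinstance(event, self.event_to_catch):
            self.caught_events.append(event)

    def flush(self) -> None:
        # Forget everything caught so far, so the next assertion starts clean.
        self.caught_events = []


class ToyEventManager:
    """Hypothetical manager that fans each fired event out to its callbacks."""

    def __init__(self) -> None:
        self.callbacks: List[Callable[[Any], None]] = []

    def add_callback(self, callback: Callable[[Any], None]) -> None:
        self.callbacks.append(callback)

    def fire(self, event: Any) -> None:
        for callback in self.callbacks:
            callback(event)


manager = ToyEventManager()
catcher = ToyEventCatcher(event_to_catch=ToyEvent)
manager.add_callback(catcher.catch)
```

Firing a `ToyEvent` through `manager` adds it to `catcher.caught_events`, while events of other types are ignored; calling `catcher.flush()` between checks keeps the assertions for separate scenarios independent.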
QMalcolm committed Jun 18, 2024
1 parent 1475abb commit da19d7b
Showing 3 changed files with 125 additions and 1 deletion.
78 changes: 77 additions & 1 deletion tests/unit/task/test_run.py
@@ -3,10 +3,19 @@

import pytest

from dbt.adapters.postgres import PostgresAdapter
from dbt.artifacts.schemas.results import RunStatus
from dbt.artifacts.schemas.run import RunResult
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ModelNode
from dbt.events.types import LogModelResult
from dbt.flags import get_flags, set_from_args
from dbt.task.run import RunTask
from dbt.task.run import ModelRunner, RunTask
from dbt.tests.util import safe_set_invocation_context
from dbt_common.events.base_types import EventLevel
from dbt_common.events.event_manager_client import add_callback_to_manager
from tests.utils import EventCatcher


@pytest.mark.parametrize(
@@ -50,3 +59,70 @@ def test_run_task_preserve_edges():
task.get_graph_queue()
# when we get the graph queue, preserve_edges is True
mock_node_selector.get_graph_queue.assert_called_with(mock_spec, True)


class TestModelRunner:
    @pytest.fixture
    def log_model_result_catcher(self) -> EventCatcher:
        catcher = EventCatcher(event_to_catch=LogModelResult)
        add_callback_to_manager(catcher.catch)
        return catcher

    @pytest.fixture
    def model_runner(
        self,
        postgres_adapter: PostgresAdapter,
        table_model: ModelNode,
        runtime_config: RuntimeConfig,
    ) -> ModelRunner:
        return ModelRunner(
            config=runtime_config,
            adapter=postgres_adapter,
            node=table_model,
            node_index=1,
            num_nodes=1,
        )

    @pytest.fixture
    def run_result(self, table_model: ModelNode) -> RunResult:
        return RunResult(
            status=RunStatus.Success,
            timing=[],
            thread_id="an_id",
            execution_time=0,
            adapter_response={},
            message="It did it",
            failures=None,
            node=table_model,
        )

    def test_print_result_line(
        self,
        log_model_result_catcher: EventCatcher,
        model_runner: ModelRunner,
        run_result: RunResult,
    ) -> None:
        # Check `print_result_line` with "successful" RunResult
        model_runner.print_result_line(run_result)
        assert len(log_model_result_catcher.caught_events) == 1
        assert log_model_result_catcher.caught_events[0].info.level == EventLevel.INFO
        assert log_model_result_catcher.caught_events[0].data.status == run_result.message

        # reset event catcher
        log_model_result_catcher.flush()

        # Check `print_result_line` with "error" RunResult
        run_result.status = RunStatus.Error
        model_runner.print_result_line(run_result)
        assert len(log_model_result_catcher.caught_events) == 1
        assert log_model_result_catcher.caught_events[0].info.level == EventLevel.ERROR
        assert log_model_result_catcher.caught_events[0].data.status == EventLevel.ERROR

    @pytest.mark.skip(
        reason="Default and adapter macros aren't being appropriately populated, leading to a runtime error"
    )
    def test_execute(
        self, table_model: ModelNode, manifest: Manifest, model_runner: ModelRunner
    ) -> None:
        model_runner.execute(model=table_model, manifest=manifest)
        # TODO: Assert that the model was executed
38 changes: 38 additions & 0 deletions tests/unit/utils/adapter.py
@@ -1,9 +1,22 @@
import sys
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
from dbt.adapters.postgres import PostgresAdapter
from dbt.adapters.sql import SQLConnectionManager
from dbt.config.runtime import RuntimeConfig
from dbt.context.providers import generate_runtime_macro_context
from dbt.contracts.graph.manifest import ManifestStateCheck
from dbt.mp_context import get_mp_context
from dbt.parser.manifest import ManifestLoader

if sys.version_info < (3, 9):
    from typing import Generator
else:
    from collections.abc import Generator


@pytest.fixture
@@ -19,3 +32,28 @@ def mock_adapter(mock_connection_manager: MagicMock) -> MagicMock:
    mock_adapter.connections = mock_connection_manager
    mock_adapter.clear_macro_resolver = MagicMock()
    return mock_adapter


@pytest.fixture
def postgres_adapter(
    mocker: MockerFixture, runtime_config: RuntimeConfig
) -> Generator[PostgresAdapter, None, None]:
    register_adapter(runtime_config, get_mp_context())
    adapter = get_adapter(runtime_config)
    assert isinstance(adapter, PostgresAdapter)

    mocker.patch(
        "dbt.parser.manifest.ManifestLoader.build_manifest_state_check"
    ).return_value = ManifestStateCheck()
    manifest = ManifestLoader.load_macros(
        runtime_config,
        adapter.connections.set_query_header,
        base_macros_only=True,
    )

    adapter.set_macro_resolver(manifest)
    adapter.set_macro_context_generator(generate_runtime_macro_context)

    yield adapter
    adapter.cleanup_connections()
    reset_adapters()
10 changes: 10 additions & 0 deletions tests/unit/utils/manifest.py
@@ -17,6 +17,7 @@
    WhereFilter,
    WhereFilterIntersection,
)
from dbt.artifacts.resources.types import ModelLanguage
from dbt.artifacts.resources.v1.model import ModelConfig
from dbt.contracts.files import AnySourceFile, FileHash
from dbt.contracts.graph.manifest import Manifest, ManifestMetadata
@@ -526,6 +527,13 @@ def macro_test_not_null() -> Macro:
    )


@pytest.fixture
def macro_materialization_table_default() -> Macro:
    macro = make_macro("dbt", "materialization_table_default", "SELECT 1")
    macro.supported_languages = [ModelLanguage.sql]
    return macro


@pytest.fixture
def macro_default_test_not_null() -> Macro:
    return make_macro("dbt", "default__test_not_null", "blabla")
@@ -964,12 +972,14 @@ def macros(
    macro_default_test_unique,
    macro_test_not_null,
    macro_default_test_not_null,
    macro_materialization_table_default,
) -> List[Macro]:
    return [
        macro_test_unique,
        macro_default_test_unique,
        macro_test_not_null,
        macro_default_test_not_null,
        macro_materialization_table_default,
    ]

