diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index c689e8f41a..19a88f6aa8 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -3,10 +3,19 @@ import pytest +from dbt.adapters.postgres import PostgresAdapter +from dbt.artifacts.schemas.results import RunStatus +from dbt.artifacts.schemas.run import RunResult from dbt.config.runtime import RuntimeConfig +from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.nodes import ModelNode +from dbt.events.types import LogModelResult from dbt.flags import get_flags, set_from_args -from dbt.task.run import RunTask +from dbt.task.run import ModelRunner, RunTask from dbt.tests.util import safe_set_invocation_context +from dbt_common.events.base_types import EventLevel +from dbt_common.events.event_manager_client import add_callback_to_manager +from tests.utils import EventCatcher @pytest.mark.parametrize( @@ -50,3 +59,70 @@ def test_run_task_preserve_edges(): task.get_graph_queue() # when we get the graph queue, preserve_edges is True mock_node_selector.get_graph_queue.assert_called_with(mock_spec, True) + + +class TestModelRunner: + @pytest.fixture + def log_model_result_catcher(self) -> EventCatcher: + catcher = EventCatcher(event_to_catch=LogModelResult) + add_callback_to_manager(catcher.catch) + return catcher + + @pytest.fixture + def model_runner( + self, + postgres_adapter: PostgresAdapter, + table_model: ModelNode, + runtime_config: RuntimeConfig, + ) -> ModelRunner: + return ModelRunner( + config=runtime_config, + adapter=postgres_adapter, + node=table_model, + node_index=1, + num_nodes=1, + ) + + @pytest.fixture + def run_result(self, table_model: ModelNode) -> RunResult: + return RunResult( + status=RunStatus.Success, + timing=[], + thread_id="an_id", + execution_time=0, + adapter_response={}, + message="It did it", + failures=None, + node=table_model, + ) + + def test_print_result_line( + self, + log_model_result_catcher: EventCatcher, + model_runner: ModelRunner, + run_result: RunResult, + ) -> None: + # Check `print_result_line` with "successful" RunResult + model_runner.print_result_line(run_result) + assert len(log_model_result_catcher.caught_events) == 1 + assert log_model_result_catcher.caught_events[0].info.level == EventLevel.INFO + assert log_model_result_catcher.caught_events[0].data.status == run_result.message + + # reset event catcher + log_model_result_catcher.flush() + + # Check `print_result_line` with "error" RunResult + run_result.status = RunStatus.Error + model_runner.print_result_line(run_result) + assert len(log_model_result_catcher.caught_events) == 1 + assert log_model_result_catcher.caught_events[0].info.level == EventLevel.ERROR + assert log_model_result_catcher.caught_events[0].data.status == EventLevel.ERROR + + @pytest.mark.skip( + reason="Default and adapter macros aren't being appropriately populated, leading to a runtime error" + ) + def test_execute( + self, table_model: ModelNode, manifest: Manifest, model_runner: ModelRunner + ) -> None: + model_runner.execute(model=table_model, manifest=manifest) + # TODO: Assert that the model was executed diff --git a/tests/unit/utils/adapter.py b/tests/unit/utils/adapter.py index 06555b0e40..c760a27ba4 100644 --- a/tests/unit/utils/adapter.py +++ b/tests/unit/utils/adapter.py @@ -1,9 +1,22 @@ +import sys from unittest.mock import MagicMock import pytest +from pytest_mock import MockerFixture +from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters from dbt.adapters.postgres import PostgresAdapter from dbt.adapters.sql import SQLConnectionManager +from dbt.config.runtime import RuntimeConfig +from dbt.context.providers import generate_runtime_macro_context +from dbt.contracts.graph.manifest import ManifestStateCheck +from dbt.mp_context import get_mp_context +from dbt.parser.manifest import ManifestLoader + +if sys.version_info < (3, 9): + from typing import Generator +else: + from collections.abc import Generator @pytest.fixture @@ -19,3 +32,28 @@ def mock_adapter(mock_connection_manager: MagicMock) -> MagicMock: mock_adapter.connections = mock_connection_manager mock_adapter.clear_macro_resolver = MagicMock() return mock_adapter + + +@pytest.fixture +def postgres_adapter( + mocker: MockerFixture, runtime_config: RuntimeConfig +) -> Generator[PostgresAdapter, None, None]: + register_adapter(runtime_config, get_mp_context()) + adapter = get_adapter(runtime_config) + assert isinstance(adapter, PostgresAdapter) + + mocker.patch( + "dbt.parser.manifest.ManifestLoader.build_manifest_state_check" + ).return_value = ManifestStateCheck() + manifest = ManifestLoader.load_macros( + runtime_config, + adapter.connections.set_query_header, + base_macros_only=True, + ) + + adapter.set_macro_resolver(manifest) + adapter.set_macro_context_generator(generate_runtime_macro_context) + + yield adapter + adapter.cleanup_connections() + reset_adapters() diff --git a/tests/unit/utils/manifest.py b/tests/unit/utils/manifest.py index bcadf1ad2e..6ac1d8d1a0 100644 --- a/tests/unit/utils/manifest.py +++ b/tests/unit/utils/manifest.py @@ -17,6 +17,7 @@ WhereFilter, WhereFilterIntersection, ) +from dbt.artifacts.resources.types import ModelLanguage from dbt.artifacts.resources.v1.model import ModelConfig from dbt.contracts.files import AnySourceFile, FileHash from dbt.contracts.graph.manifest import Manifest, ManifestMetadata @@ -526,6 +527,13 @@ def macro_test_not_null() -> Macro: ) +@pytest.fixture +def macro_materialization_table_default() -> Macro: + macro = make_macro("dbt", "materialization_table_default", "SELECT 1") + macro.supported_languages = [ModelLanguage.sql] + return macro + + @pytest.fixture def macro_default_test_not_null() -> Macro: return make_macro("dbt", "default__test_not_null", "blabla") @@ -964,12 +972,14 @@ def macros( macro_default_test_unique, macro_test_not_null, macro_default_test_not_null, + macro_materialization_table_default, ) -> List[Macro]: return [ macro_test_unique, macro_default_test_unique, macro_test_not_null, macro_default_test_not_null, + macro_materialization_table_default, ]