Skip to content

Commit

Permalink
Improved flags fixturing for for repository unit tests (#10190)
Browse files Browse the repository at this point in the history
* Add fixtures for setting and resettign flags for unit tests

* Remove unnecessary `set_from_args` in non `unittest.TestCase` based unit tests

In the previous commit we added a pytest fixture which sets and tears down
the global flags arg via `set_from_args` for every pytest based unit test.
Previously we had added a `set_from_args` in tests or test files to reset
the global flags from if they were modified by a previous test. This is no
longer necessary because of the work done in the previous commit.

Note: We did not modify any tests that use the `unittest.TestCase` class
because they don't use pytest fixtures. Thus those tests need to continue
operating as they currently do until we shift them to pytest unit tests.

* Utilize the new `args_for_flags` fixture for setting of flags in `test_contracts_graph_parsed.py`

* Convert `test_compilation.py` from `TestCase` tests to pytest tests

We did this so in the next commit we can drop the unnecessary `set_from_args`
in the next commit. That will be it's own commit because converting these
tests is a restructuring that doing separately makes things easier to follow.
That is to say, all changes in this commit were just to convert the tests to
pytest, no other changes were made.

* Drop unnecessary `set_from_args` in `test_compilation.py`

* Add return types to all methods in `test_compilation.py`

* Reduce imports from `compilation` in `test_compilation.py`

* Update `test_logging.py` now that we don't need to worry about global flags

* Conditionally import `Generator` type for python 3.8

In python 3.9 `Generator` was moved to `collections.abc` and deprecated
in `typing`. We still support 3.8 and thus need to be conditionally
importing `Generator`. We should remove this in the future when we drop
support for 3.8.
  • Loading branch information
QMalcolm committed May 21, 2024
1 parent 0d297c2 commit 09243d1
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 116 deletions.
1 change: 1 addition & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# All manifest related fixtures.
from tests.unit.utils.adapter import * # noqa
from tests.unit.utils.event_manager import * # noqa
from tests.unit.utils.flags import * # noqa
from tests.unit.utils.manifest import * # noqa
from tests.unit.utils.project import * # noqa

Expand Down
4 changes: 0 additions & 4 deletions tests/unit/context/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from argparse import Namespace
from typing import Any, Dict, Set
from unittest import mock

Expand All @@ -19,14 +18,11 @@
UnitTestNode,
UnitTestOverrides,
)
from dbt.flags import set_from_args
from dbt.node_types import NodeType
from dbt_common.events.functions import reset_metadata_vars
from tests.unit.mock_adapter import adapter_factory
from tests.unit.utils import clear_plugin, config_from_parts_or_dicts, inject_adapter

set_from_args(Namespace(WARN_ERROR=False), None)


class TestVar:
@pytest.fixture
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/context/test_query_header.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import re
from argparse import Namespace
from unittest import mock

import pytest

from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.context.query_header import generate_query_header_context
from dbt.flags import set_from_args
from tests.unit.utils import config_from_parts_or_dicts

set_from_args(Namespace(WARN_ERROR=False), None)


class TestQueryHeaderContext:
@pytest.fixture
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/events/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from argparse import Namespace
from copy import deepcopy

from pytest_mock import MockerFixture

Expand All @@ -19,12 +18,10 @@ def test_clears_preexisting_event_manager_state(self) -> None:
assert len(manager.loggers) == 1
assert len(manager.callbacks) == 1

flags = deepcopy(get_flags())
# setting both of these to none guarantees that no logger will be added
object.__setattr__(flags, "LOG_LEVEL", "none")
object.__setattr__(flags, "LOG_LEVEL_FILE", "none")
args = Namespace(log_level="none", log_level_file="none")
set_from_args(args, {})

setup_event_logger(flags=flags)
setup_event_logger(get_flags())
assert len(manager.loggers) == 0
assert len(manager.callbacks) == 0

Expand Down
4 changes: 0 additions & 4 deletions tests/unit/parser/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_partial_parse_file_path(self, patched_open, patched_os_exist, patched_s
mock_project = MagicMock(RuntimeConfig)
mock_project.project_target_path = "mock_target_path"
patched_os_exist.return_value = True
set_from_args(Namespace(), {})
ManifestLoader(mock_project, {})
# by default we use the project_target_path
patched_open.assert_called_with("mock_target_path/partial_parse.msgpack", "rb")
Expand All @@ -33,7 +32,6 @@ def test_profile_hash_change(self, mock_project):
# This test validate that the profile_hash is updated when the connection keys change
profile_hash = "750bc99c1d64ca518536ead26b28465a224be5ffc918bf2a490102faa5a1bcf5"
mock_project.credentials.connection_info.return_value = "test"
set_from_args(Namespace(), {})
manifest = ManifestLoader(mock_project, {})
assert manifest.manifest.state_check.profile_hash.checksum == profile_hash
mock_project.credentials.connection_info.return_value = "test1"
Expand Down Expand Up @@ -67,7 +65,6 @@ def test_partial_parse_safe_update_project_parser_files_partially(
mock_saved_manifest.files = {}
patched_read_manifest_for_partial_parse.return_value = mock_saved_manifest

set_from_args(Namespace(), {})
loader = ManifestLoader(mock_project, {})
loader.safe_update_project_parser_files_partially({})

Expand Down Expand Up @@ -150,7 +147,6 @@ def test_partial_parse_file_diff_flag(
mock_file_diff = mocker.patch("dbt.parser.read_files.FileDiff.from_dict")
mock_file_diff.return_value = FileDiff([], [], [])

set_from_args(Namespace(), {})
ManifestLoader.get_full_manifest(config=mock_project)
assert not mock_file_diff.called

Expand Down
149 changes: 77 additions & 72 deletions tests/unit/test_compilation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os
import tempfile
import unittest
from argparse import Namespace
from queue import Empty
from unittest import mock

from dbt import compilation
from dbt.flags import set_from_args
import pytest

from dbt.compilation import Graph, Linker
from dbt.graph.cli import parse_difference
from dbt.graph.queue import GraphQueue
from dbt.graph.selector import NodeSelector

set_from_args(Namespace(WARN_ERROR=False), None)


def _mock_manifest(nodes):
config = mock.MagicMock(enabled=True)
Expand All @@ -33,41 +31,48 @@ def _mock_manifest(nodes):
return manifest


class LinkerTest(unittest.TestCase):
def setUp(self):
self.linker = compilation.Linker()
class TestLinker:
@pytest.fixture
def linker(self) -> Linker:
return Linker()

def test_linker_add_node(self):
def test_linker_add_node(self, linker: Linker) -> None:
expected_nodes = ["A", "B", "C"]
for node in expected_nodes:
self.linker.add_node(node)
linker.add_node(node)

actual_nodes = self.linker.nodes()
actual_nodes = linker.nodes()
for node in expected_nodes:
self.assertIn(node, actual_nodes)
assert node in actual_nodes

self.assertEqual(len(actual_nodes), len(expected_nodes))
assert len(actual_nodes) == len(expected_nodes)

def test_linker_write_graph(self):
def test_linker_write_graph(self, linker: Linker) -> None:
expected_nodes = ["A", "B", "C"]
for node in expected_nodes:
self.linker.add_node(node)
linker.add_node(node)

manifest = _mock_manifest("ABC")
(fd, fname) = tempfile.mkstemp()
os.close(fd)
try:
self.linker.write_graph(fname, manifest)
linker.write_graph(fname, manifest)
assert os.path.exists(fname)
finally:
os.unlink(fname)

def assert_would_join(self, queue):
def assert_would_join(self, queue: GraphQueue) -> None:
"""test join() without timeout risk"""
self.assertEqual(queue.inner.unfinished_tasks, 0)

def _get_graph_queue(self, manifest, include=None, exclude=None):
graph = compilation.Graph(self.linker.graph)
assert queue.inner.unfinished_tasks == 0

def _get_graph_queue(
self,
manifest,
linker: Linker,
include=None,
exclude=None,
) -> GraphQueue:
graph = Graph(linker.graph)
selector = NodeSelector(graph, manifest)
# TODO: The "eager" string below needs to be replaced with programatic access
# to the default value for the indirect selection parameter in
Expand All @@ -77,114 +82,114 @@ def _get_graph_queue(self, manifest, include=None, exclude=None):
spec = parse_difference(include, exclude)
return selector.get_graph_queue(spec)

def test_linker_add_dependency(self):
def test_linker_add_dependency(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("A", "C"), ("B", "C")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

queue = self._get_graph_queue(_mock_manifest("ABC"))
queue = self._get_graph_queue(_mock_manifest("ABC"), linker)

got = queue.get(block=False)
self.assertEqual(got.unique_id, "C")
with self.assertRaises(Empty):
assert got.unique_id == "C"
with pytest.raises(Empty):
queue.get(block=False)
self.assertFalse(queue.empty())
assert not queue.empty()
queue.mark_done("C")
self.assertFalse(queue.empty())
assert not queue.empty()

got = queue.get(block=False)
self.assertEqual(got.unique_id, "B")
with self.assertRaises(Empty):
assert got.unique_id == "B"
with pytest.raises(Empty):
queue.get(block=False)
self.assertFalse(queue.empty())
assert not queue.empty()
queue.mark_done("B")
self.assertFalse(queue.empty())
assert not queue.empty()

got = queue.get(block=False)
self.assertEqual(got.unique_id, "A")
with self.assertRaises(Empty):
assert got.unique_id == "A"
with pytest.raises(Empty):
queue.get(block=False)
self.assertTrue(queue.empty())
assert queue.empty()
queue.mark_done("A")
self.assert_would_join(queue)
self.assertTrue(queue.empty())
assert queue.empty()

def test_linker_add_disjoint_dependencies(self):
def test_linker_add_disjoint_dependencies(self, linker: Linker) -> None:
actual_deps = [("A", "B")]
additional_node = "Z"

for (l, r) in actual_deps:
self.linker.dependency(l, r)
self.linker.add_node(additional_node)
linker.dependency(l, r)
linker.add_node(additional_node)

queue = self._get_graph_queue(_mock_manifest("ABCZ"))
queue = self._get_graph_queue(_mock_manifest("ABCZ"), linker)
# the first one we get must be B, it has the longest dep chain
first = queue.get(block=False)
self.assertEqual(first.unique_id, "B")
self.assertFalse(queue.empty())
assert first.unique_id == "B"
assert not queue.empty()
queue.mark_done("B")
self.assertFalse(queue.empty())
assert not queue.empty()

second = queue.get(block=False)
self.assertIn(second.unique_id, {"A", "Z"})
self.assertFalse(queue.empty())
assert second.unique_id in {"A", "Z"}
assert not queue.empty()
queue.mark_done(second.unique_id)
self.assertFalse(queue.empty())
assert not queue.empty()

third = queue.get(block=False)
self.assertIn(third.unique_id, {"A", "Z"})
with self.assertRaises(Empty):
assert third.unique_id in {"A", "Z"}
with pytest.raises(Empty):
queue.get(block=False)
self.assertNotEqual(second.unique_id, third.unique_id)
self.assertTrue(queue.empty())
assert second.unique_id != third.unique_id
assert queue.empty()
queue.mark_done(third.unique_id)
self.assert_would_join(queue)
self.assertTrue(queue.empty())
assert queue.empty()

def test_linker_dependencies_limited_to_some_nodes(self):
def test_linker_dependencies_limited_to_some_nodes(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("B", "C"), ("C", "D")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

queue = self._get_graph_queue(_mock_manifest("ABCD"), ["B"])
queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["B"])
got = queue.get(block=False)
self.assertEqual(got.unique_id, "B")
self.assertTrue(queue.empty())
assert got.unique_id == "B"
assert queue.empty()
queue.mark_done("B")
self.assert_would_join(queue)

queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), ["A", "B"])
queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["A", "B"])
got = queue_2.get(block=False)
self.assertEqual(got.unique_id, "B")
self.assertFalse(queue_2.empty())
with self.assertRaises(Empty):
assert got.unique_id == "B"
assert not queue_2.empty()
with pytest.raises(Empty):
queue_2.get(block=False)
queue_2.mark_done("B")
self.assertFalse(queue_2.empty())
assert not queue_2.empty()

got = queue_2.get(block=False)
self.assertEqual(got.unique_id, "A")
self.assertTrue(queue_2.empty())
with self.assertRaises(Empty):
assert got.unique_id == "A"
assert queue_2.empty()
with pytest.raises(Empty):
queue_2.get(block=False)
self.assertTrue(queue_2.empty())
assert queue_2.empty()
queue_2.mark_done("A")
self.assert_would_join(queue_2)

def test__find_cycles__cycles(self):
def test__find_cycles__cycles(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("B", "C"), ("C", "A")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

self.assertIsNotNone(self.linker.find_cycles())
assert linker.find_cycles() is not None

def test__find_cycles__no_cycles(self):
def test__find_cycles__no_cycles(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("B", "C"), ("C", "D")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

self.assertIsNone(self.linker.find_cycles())
assert linker.find_cycles() is None
6 changes: 4 additions & 2 deletions tests/unit/test_contracts_graph_parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from hypothesis import given
from hypothesis.strategies import builds, lists

from dbt import flags
from dbt.artifacts.resources import (
ColumnInfo,
Dimension,
Expand Down Expand Up @@ -67,7 +66,10 @@
replace_config,
)

flags.set_from_args(Namespace(SEND_ANONYMOUS_USAGE_STATS=False), None)

@pytest.fixture
def flags_for_args() -> Namespace:
return Namespace(SEND_ANONYMOUS_USAGE_STATS=False)


@pytest.fixture
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/test_deprecations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from argparse import Namespace

from dbt.flags import set_from_args
from dbt.internal_deprecations import deprecated


Expand All @@ -11,6 +8,5 @@ def to_be_decorated():

# simple test that the return value is not modified
def test_deprecated_func():
set_from_args(Namespace(WARN_ERROR=False), None)
assert hasattr(to_be_decorated, "__wrapped__")
assert to_be_decorated() == 5
Loading

0 comments on commit 09243d1

Please sign in to comment.