diff --git a/core/dbt/constants.py b/core/dbt/constants.py
index c5b949f83aa..1599df3e335 100644
--- a/core/dbt/constants.py
+++ b/core/dbt/constants.py
@@ -1,2 +1,3 @@
 SECRET_ENV_PREFIX = "DBT_ENV_SECRET_"
 DEFAULT_ENV_PLACEHOLDER = "DBT_DEFAULT_PLACEHOLDER"
+METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"
diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py
index 7adb3651280..c053d28d1df 100644
--- a/core/dbt/context/providers.py
+++ b/core/dbt/context/providers.py
@@ -41,7 +41,7 @@
     ParsedSourceDefinition,
 )
 from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
-from dbt.contracts.util import get_metadata_env
+from dbt.events.functions import get_metadata_vars
 from dbt.exceptions import (
     CompilationException,
     ParsingException,
@@ -713,7 +713,7 @@ def _get_namespace_builder(self):
 
     @contextproperty
     def dbt_metadata_envs(self) -> Dict[str, str]:
-        return get_metadata_env()
+        return get_metadata_vars()
 
     @contextproperty
     def invocation_args_dict(self):
diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py
index bc6c32ab237..f0975fda10b 100644
--- a/core/dbt/contracts/util.py
+++ b/core/dbt/contracts/util.py
@@ -1,5 +1,4 @@
 import dataclasses
-import os
 from datetime import datetime
 from typing import List, Tuple, ClassVar, Type, TypeVar, Dict, Any, Optional
 
@@ -11,7 +10,7 @@
     IncompatibleSchemaException,
 )
 from dbt.version import __version__
-from dbt.events.functions import get_invocation_id
+from dbt.events.functions import get_invocation_id, get_metadata_vars
 from dbt.dataclass_schema import dbtClassMixin
 
 from dbt.dataclass_schema import (
@@ -148,20 +147,6 @@ def __str__(self) -> str:
         return BASE_SCHEMAS_URL + self.path
 
 
-SCHEMA_VERSION_KEY = "dbt_schema_version"
-
-
-METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_"
-
-
-def get_metadata_env() -> Dict[str, str]:
-    return {
-        k[len(METADATA_ENV_PREFIX) :]: v
-        for k, v in os.environ.items()
-        if k.startswith(METADATA_ENV_PREFIX)
-    }
-
-
 # This is used in the ManifestMetadata, RunResultsMetadata, RunOperationResultMetadata,
 # FreshnessMetadata, and CatalogMetadata classes
 @dataclasses.dataclass
@@ -170,7 +155,7 @@ class BaseArtifactMetadata(dbtClassMixin):
     dbt_version: str = __version__
     generated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
     invocation_id: Optional[str] = dataclasses.field(default_factory=get_invocation_id)
-    env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_env)
+    env: Dict[str, str] = dataclasses.field(default_factory=get_metadata_vars)
 
     def __post_serialize__(self, dct):
         dct = super().__post_serialize__(dct)
diff --git a/core/dbt/events/base_types.py b/core/dbt/events/base_types.py
index 5ed447a48cc..489b70cb1ad 100644
--- a/core/dbt/events/base_types.py
+++ b/core/dbt/events/base_types.py
@@ -3,7 +3,6 @@
 import threading
 from datetime import datetime
 
-
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
 # These base types define the _required structure_ for the concrete event #
 # types defined in types.py                                               #
@@ -15,6 +14,12 @@ class Cache:
     pass
 
 
+def get_global_metadata_vars() -> dict:
+    from dbt.events.functions import get_metadata_vars
+
+    return get_metadata_vars()
+
+
 def get_invocation_id() -> str:
     from dbt.events.functions import get_invocation_id
 
@@ -48,6 +53,7 @@ def __post_init__(self):
         if not hasattr(self.info, "msg") or not self.info.msg:
             self.info.msg = self.message()
         self.info.invocation_id = get_invocation_id()
+        self.info.extra = get_global_metadata_vars()
         self.info.ts = datetime.utcnow()
         self.info.pid = get_pid()
         self.info.thread = get_thread_name()
diff --git a/core/dbt/events/functions.py b/core/dbt/events/functions.py
index 51799b3fed7..3276ec467ed 100644
--- a/core/dbt/events/functions.py
+++ b/core/dbt/events/functions.py
@@ -3,7 +3,7 @@
 from dbt.events.base_types import NoStdOut, BaseEvent, NoFile, Cache
 from dbt.events.types import EventBufferFull, MainReportVersion, EmptyLine
 import dbt.flags as flags
-from dbt.constants import SECRET_ENV_PREFIX
+from dbt.constants import SECRET_ENV_PREFIX, METADATA_ENV_PREFIX
 from dbt.logger import make_log_dir_if_missing, GLOBAL_LOGGER
 from datetime import datetime
 
@@ -18,7 +18,7 @@
 import os
 import uuid
 import threading
-from typing import List, Optional, Union, Callable
+from typing import List, Optional, Union, Callable, Dict
 from collections import deque
 
 LOG_VERSION = 3
@@ -40,6 +40,7 @@
 format_color = True
 format_json = False
 invocation_id: Optional[str] = None
+metadata_vars: Optional[Dict[str, str]] = None
 
 
 def setup_event_logger(log_path, level_override=None):
@@ -267,6 +268,22 @@ def fire_event(e: BaseEvent) -> None:
             send_to_logger(STDOUT_LOG, level_tag=e.level_tag(), log_line=log_line)
 
 
+def get_metadata_vars() -> Dict[str, str]:
+    global metadata_vars
+    if metadata_vars is None:
+        metadata_vars = {
+            k[len(METADATA_ENV_PREFIX) :]: v
+            for k, v in os.environ.items()
+            if k.startswith(METADATA_ENV_PREFIX)
+        }
+    return metadata_vars
+
+
+def reset_metadata_vars() -> None:
+    global metadata_vars
+    metadata_vars = None
+
+
 def get_invocation_id() -> str:
     global invocation_id
     if invocation_id is None:
diff --git a/core/dbt/lib.py b/core/dbt/lib.py
index 11d4e07e524..ff8f06c88a8 100644
--- a/core/dbt/lib.py
+++ b/core/dbt/lib.py
@@ -90,6 +90,7 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
 
     # Make sure we have a valid invocation_id
     dbt.events.functions.set_invocation_id()
+    dbt.events.functions.reset_metadata_vars()
 
     return config
 
diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py
index b4fb24fdd50..af837c18b17 100644
--- a/core/dbt/tests/util.py
+++ b/core/dbt/tests/util.py
@@ -11,7 +11,7 @@
 from dbt.main import handle_and_check
 from dbt.logger import log_manager
 from dbt.contracts.graph.manifest import Manifest
-from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs
+from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars
 from dbt.events.test_types import IntegrationTestDebug
 
 # =============================================================================
@@ -63,6 +63,9 @@ def run_dbt(args: List[str] = None, expect_pass=True):
     # Ignore logbook warnings
     warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook")
 
+    # reset global vars
+    reset_metadata_vars()
+
     # The logger will complain about already being initialized if
     # we don't do this.
     log_manager.reset_handlers()
diff --git a/test/unit/test_context.py b/test/unit/test_context.py
index e1737103e1f..668d76cc525 100644
--- a/test/unit/test_context.py
+++ b/test/unit/test_context.py
@@ -19,6 +19,7 @@
 from dbt.config.project import VarProvider
 from dbt.context import base, target, configured, providers, docs, manifest, macros
 from dbt.contracts.files import FileHash
+from dbt.events.functions import reset_metadata_vars
 from dbt.node_types import NodeType
 import dbt.exceptions
 from .utils import (
@@ -503,6 +504,8 @@ def test_macro_namespace(config_postgres, manifest_fx):
     assert result["some_macro"].macro is package_macro
 
 def test_dbt_metadata_envs(monkeypatch, config_postgres, manifest_fx, get_adapter, get_include_paths):
+    reset_metadata_vars()
+
     envs = {
         "DBT_ENV_CUSTOM_ENV_RUN_ID": 1234,
         "DBT_ENV_CUSTOM_ENV_JOB_ID": 5678,
@@ -519,3 +522,6 @@ def test_dbt_metadata_envs(monkeypatch, config_postgres, manifest_fx, get_adapte
     )
 
     assert ctx["dbt_metadata_envs"] == {'JOB_ID': 5678, 'RUN_ID': 1234}
+
+    # cleanup
+    reset_metadata_vars()
diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py
index e6a83ac0759..cbce93fc052 100644
--- a/test/unit/test_manifest.py
+++ b/test/unit/test_manifest.py
@@ -34,7 +34,8 @@
 )
 
 from dbt.contracts.graph.compiled import CompiledModelNode
-from dbt.events.functions import get_invocation_id
+from dbt.events.functions import reset_metadata_vars
+
 from dbt.node_types import NodeType
 import freezegun
 
@@ -60,6 +61,8 @@
 
 class ManifestTest(unittest.TestCase):
     def setUp(self):
+        reset_metadata_vars()
+
         # TODO: why is this needed for tests in this module to pass?
         tracking.active_user = None
 
@@ -304,6 +307,7 @@ def setUp(self):
 
     def tearDown(self):
         del os.environ['DBT_ENV_CUSTOM_ENV_key']
+        reset_metadata_vars()
 
     @freezegun.freeze_time('2018-02-14T09:15:13Z')
     def test__no_nodes(self):
diff --git a/tests/functional/context_methods/test_custom_env_vars.py b/tests/functional/context_methods/test_custom_env_vars.py
new file mode 100644
index 00000000000..413789c7676
--- /dev/null
+++ b/tests/functional/context_methods/test_custom_env_vars.py
@@ -0,0 +1,33 @@
+import pytest
+import json
+import os
+
+from dbt.tests.util import run_dbt_and_capture
+
+
+def parse_json_logs(json_log_output):
+    parsed_logs = []
+    for line in json_log_output.split("\n"):
+        try:
+            log = json.loads(line)
+        except ValueError:
+            continue
+
+        parsed_logs.append(log)
+
+    return parsed_logs
+
+
+class TestCustomVarInLogs:
+    @pytest.fixture(scope="class", autouse=True)
+    def setup(self):
+        # on windows, python uppercases env var names because windows is case insensitive
+        os.environ["DBT_ENV_CUSTOM_ENV_SOME_VAR"] = "value"
+        yield
+        del os.environ["DBT_ENV_CUSTOM_ENV_SOME_VAR"]
+
+    def test_extra_filled(self, project):
+        _, log_output = run_dbt_and_capture(['--log-format=json', 'deps'],)
+        logs = parse_json_logs(log_output)
+        for log in logs:
+            assert log['info'].get('extra') == {"SOME_VAR": "value"}
diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py
index 34f6ab5d019..46e9479ef39 100644
--- a/tests/unit/test_proto_events.py
+++ b/tests/unit/test_proto_events.py
@@ -7,7 +7,7 @@
     PluginLoadError,
     PrintStartLine,
 )
-from dbt.events.functions import event_to_dict, LOG_VERSION
+from dbt.events.functions import event_to_dict, LOG_VERSION, reset_metadata_vars
 from dbt.events import proto_types as pl
 from dbt.version import installed
 
@@ -97,3 +97,27 @@ def test_node_info_events():
     )
     assert event
     assert event.node_info.node_path == "some_path"
+
+
+def test_extra_dict_on_event(monkeypatch):
+
+    monkeypatch.setenv("DBT_ENV_CUSTOM_ENV_env_key", "env_value")
+
+    reset_metadata_vars()
+
+    event = MainReportVersion(version=str(installed), log_version=LOG_VERSION)
+    event_dict = event_to_dict(event)
+    assert set(event_dict["info"].keys()) == info_keys
+    assert event.info.extra == {"env_key": "env_value"}
+    serialized = bytes(event)
+
+    # Extract EventInfo from serialized message
+    generic_event = pl.GenericMessage().parse(serialized)
+    assert generic_event.info.code == "A001"
+    # get the message class for the real message from the generic message
+    message_class = getattr(sys.modules["dbt.events.proto_types"], generic_event.info.name)
+    new_event = message_class().parse(serialized)
+    assert new_event.info.extra == event.info.extra
+
+    # clean up
+    reset_metadata_vars()