Skip to content

Commit

Permalink
Try to make imports a little more sane, ordering-wise
Browse files Browse the repository at this point in the history
consolidate dbt.ui, move non-rpc node_runners into their tasks
move parse_cli_vars into config.utils
get rid of logger/exceptions requirements in dbt.utils
  • Loading branch information
Jacob Beck committed Jun 25, 2020
1 parent 62a0bf8 commit 32c5598
Show file tree
Hide file tree
Showing 46 changed files with 1,033 additions and 1,009 deletions.
8 changes: 4 additions & 4 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import agate

import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle
)
Expand All @@ -19,6 +18,7 @@
MacroQueryStringSetter,
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import flags


class BaseConnectionManager(metaclass=abc.ABCMeta):
Expand All @@ -39,7 +39,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
def __init__(self, profile: AdapterRequiredConfig):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
self.lock: RLock = flags.MP_CONTEXT.RLock()
self.query_header: Optional[MacroQueryStringSetter] = None

def set_query_header(self, manifest: Manifest) -> None:
Expand Down Expand Up @@ -235,7 +235,7 @@ def _close_handle(cls, connection: Connection) -> None:
@classmethod
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
Expand All @@ -253,7 +253,7 @@ def _rollback(cls, connection: Connection) -> None:

@classmethod
def close(cls, connection: Connection) -> Connection:
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
Expand Down
14 changes: 7 additions & 7 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
get_relation_returned_multiple_results,
InternalException, NotImplementedException, RuntimeException,
)
import dbt.flags
from dbt import flags

from dbt import deprecations
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
Expand Down Expand Up @@ -267,7 +267,7 @@ def load_internal_manifest(self) -> Manifest:
def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
"""Check if the schema is cached, and by default logs if it is not."""

if dbt.flags.USE_CACHE is False:
if flags.USE_CACHE is False:
return False
elif (database, schema) not in self.cache:
logger.debug(
Expand Down Expand Up @@ -323,7 +323,7 @@ def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
if not dbt.flags.USE_CACHE:
if not flags.USE_CACHE:
return

cache_schemas = self._get_cache_schemas(manifest)
Expand Down Expand Up @@ -352,7 +352,7 @@ def set_relations_cache(
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
"""
if not dbt.flags.USE_CACHE:
if not flags.USE_CACHE:
return

with self.cache.lock:
Expand All @@ -368,7 +368,7 @@ def cache_added(self, relation: Optional[BaseRelation]) -> str:
raise_compiler_error(
'Attempted to cache a null relation for {}'.format(name)
)
if dbt.flags.USE_CACHE:
if flags.USE_CACHE:
self.cache.add(relation)
# so jinja doesn't render things
return ''
Expand All @@ -383,7 +383,7 @@ def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
raise_compiler_error(
'Attempted to drop a null relation for {}'.format(name)
)
if dbt.flags.USE_CACHE:
if flags.USE_CACHE:
self.cache.drop(relation)
return ''

Expand All @@ -405,7 +405,7 @@ def cache_renamed(
.format(src_name, dst_name, name)
)

if dbt.flags.USE_CACHE:
if flags.USE_CACHE:
self.cache.rename(from_relation, to_relation)
return ''

Expand Down
7 changes: 4 additions & 3 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import dbt.clients.agate_helper
import dbt.exceptions
from dbt.contracts.connection import Connection
from dbt.adapters.base import BaseConnectionManager
from dbt.contracts.connection import Connection
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import flags


class SQLConnectionManager(BaseConnectionManager):
Expand Down Expand Up @@ -133,7 +134,7 @@ def add_commit_query(self):
def begin(self):
connection = self.get_thread_connection()

if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
Expand All @@ -151,7 +152,7 @@ def begin(self):

def commit(self):
connection = self.get_thread_connection()
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
Expand Down
1 change: 0 additions & 1 deletion core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dbt.clients.agate_helper
from dbt.contracts.connection import Connection
import dbt.exceptions
import dbt.flags
from dbt.adapters.base import BaseAdapter, available
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
InternalException, raise_compiler_error, CompilationException,
invalid_materialization_argument, MacroReturn
)
from dbt.flags import MACRO_DEBUGGING
from dbt import flags
from dbt.logger import GLOBAL_LOGGER as logger # noqa


Expand Down Expand Up @@ -93,8 +93,8 @@ def _compile(self, source, filename):
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
if filename == '<template>' and MACRO_DEBUGGING:
write = MACRO_DEBUGGING == 'write'
if filename == '<template>' and flags.MACRO_DEBUGGING:
write = flags.MACRO_DEBUGGING == 'write'
filename = _linecache_inject(source, write)

return super()._compile(source, filename) # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import dbt.include
import dbt.tracking

from dbt import flags
from dbt.node_types import NodeType
from dbt.linker import Linker

from dbt.context.providers import generate_runtime_model
from dbt.contracts.graph.compiled import NonSourceNode
from dbt.contracts.graph.manifest import Manifest
import dbt.exceptions
import dbt.flags
import dbt.config
from dbt.contracts.graph.compiled import (
InjectedCTE,
Expand Down Expand Up @@ -104,7 +104,7 @@ def recursively_prepend_ctes(model, manifest):
if model.extra_ctes_injected:
return (model, model.extra_ctes, manifest)

if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(model, tuple(COMPILED_TYPES.values())):
raise dbt.exceptions.InternalException(
'Bad model type: {}'.format(type(model))
Expand Down Expand Up @@ -187,7 +187,7 @@ def compile_node(self, node, manifest, extra_context=None):
def write_graph_file(self, linker: Linker, manifest: Manifest):
filename = graph_file_name
graph_path = os.path.join(self.config.target_path, filename)
if dbt.flags.WRITE_JSON:
if flags.WRITE_JSON:
linker.write_graph(graph_path, manifest)

def link_node(
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .profile import Profile
from .project import Project
from .renderer import DbtProjectYamlRenderer, ProfileRenderer
from .utils import parse_cli_vars
from dbt import tracking
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
from dbt.helper_types import FQNPath, PathSet
Expand All @@ -19,8 +20,7 @@
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
from dbt.contracts.graph.manifest import ManifestMetadata
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.ui import printer
from dbt.utils import parse_cli_vars
from dbt.ui import warning_tag

from dbt.contracts.project import Configuration, UserConfig
from dbt.exceptions import (
Expand Down Expand Up @@ -317,7 +317,7 @@ def warn_for_unused_resource_config_paths(
'\n'.join('- {}'.format('.'.join(u)) for u in unused)
)

warn_or_error(msg, log_fmt=printer.warning_tag('{}'))
warn_or_error(msg, log_fmt=warning_tag('{}'))

def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
if self.dependencies is None:
Expand Down
23 changes: 23 additions & 0 deletions core/dbt/config/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Dict, Any

from dbt.clients import yaml_helper
from dbt.exceptions import raise_compiler_error, ValidationException
from dbt.logger import GLOBAL_LOGGER as logger


def parse_cli_vars(var_string: str) -> Dict[str, Any]:
try:
cli_vars = yaml_helper.load_yaml_text(var_string)
var_type = type(cli_vars)
if var_type is dict:
return cli_vars
else:
type_name = var_type.__name__
raise_compiler_error(
"The --vars argument must be a YAML dictionary, but was "
"of type '{}'".format(type_name))
except ValidationException:
logger.error(
"The YAML provided in the --vars argument is not valid.\n"
)
raise
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from hologram import JsonSchemaMixin, ValidationError
from hologram.helpers import StrEnum, register_pattern

from dbt import hooks
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
from dbt.exceptions import CompilationException, InternalException
from dbt.contracts.util import Replaceable, list_str
from dbt import hooks
from dbt.node_types import NodeType


Expand Down
6 changes: 3 additions & 3 deletions core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from hologram import JsonSchemaMixin

from dbt.clients.system import write_file
import dbt.flags
from dbt.contracts.graph.unparsed import (
UnparsedNode, UnparsedDocumentation, Quoting, Docs,
UnparsedBaseNode, FreshnessThreshold, ExternalTable,
Expand All @@ -24,6 +23,7 @@
)
from dbt.contracts.util import Replaceable
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import flags
from dbt.node_types import NodeType


Expand Down Expand Up @@ -117,7 +117,7 @@ def patch(self, patch: 'ParsedNodePatch'):
self.columns = patch.columns
self.meta = patch.meta
self.docs = patch.docs
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin)
self.to_dict(validate=True)

Expand Down Expand Up @@ -300,7 +300,7 @@ def patch(self, patch: ParsedMacroPatch):
self.meta = patch.meta
self.docs = patch.docs
self.arguments = patch.arguments
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
assert isinstance(self, JsonSchemaMixin)
self.to_dict(validate=True)

Expand Down
6 changes: 3 additions & 3 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dbt.helper_types import NoValue
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import tracking
from dbt.ui import printer
from dbt import ui

from hologram import JsonSchemaMixin, ValidationError
from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
Expand Down Expand Up @@ -268,10 +268,10 @@ def set_values(self, cookie_dir):
tracking.do_not_track()

if self.use_colors:
printer.use_colors()
ui.use_colors()

if self.printer_width:
printer.printer_width(self.printer_width)
ui.printer_width(self.printer_width)


@dataclass
Expand Down
6 changes: 2 additions & 4 deletions core/dbt/deprecations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional, Set, List, Dict, ClassVar

import dbt.links
import dbt.exceptions
import dbt.flags
from dbt.ui import printer
from dbt import ui


class DBTDeprecation:
Expand All @@ -29,7 +27,7 @@ def description(self) -> str:
def show(self, *args, **kwargs) -> None:
if self.name not in active_deprecations:
desc = self.description.format(**kwargs)
msg = printer.line_wrap_message(
msg = ui.line_wrap_message(
desc, prefix='* Deprecation Warning: '
)
dbt.exceptions.warn_or_error(msg)
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/deps/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ExecutableError, warn_or_error, raise_dependency_error
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.ui import printer
from dbt import ui

PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa

Expand Down Expand Up @@ -77,7 +77,7 @@ def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata:
'The git package "{}" is not pinned.\n\tThis can introduce '
'breaking changes into your project without warning!\n\nSee {}'
.format(self.git, PIN_PACKAGE_URL),
log_fmt=printer.yellow('WARNING: {}')
log_fmt=ui.yellow('WARNING: {}')
)
loaded = Project.from_project_root(path, renderer)
return ProjectPackageMetadata.from_project(loaded)
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
import dbt.flags
from dbt import flags
from dbt.ui import line_wrap_message

import hologram

Expand Down Expand Up @@ -867,7 +868,6 @@ def raise_unrecognized_credentials_type(typename, supported_types):
def raise_invalid_patch(
node, patch_section: str, patch_path: str,
) -> NoReturn:
from dbt.ui.printer import line_wrap_message
msg = line_wrap_message(
f'''\
'{node.name}' is a {node.resource_type} node, but it is
Expand Down Expand Up @@ -904,7 +904,7 @@ def raise_duplicate_alias(


def warn_or_error(msg, node=None, log_fmt=None):
if dbt.flags.WARN_ERROR:
if flags.WARN_ERROR:
raise_compiler_error(msg, node)
else:
if log_fmt is not None:
Expand All @@ -913,7 +913,7 @@ def warn_or_error(msg, node=None, log_fmt=None):


def warn_or_raise(exc, log_fmt=None):
if dbt.flags.WARN_ERROR:
if flags.WARN_ERROR:
raise exc
else:
msg = str(exc)
Expand Down
Loading

0 comments on commit 32c5598

Please sign in to comment.