From 9cc7a7a87fcfdc2c558b91c9316efb22bf3fdb27 Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Tue, 17 Dec 2019 11:13:33 -0700
Subject: [PATCH] Fix mypy checking

Make mypy check our nested namespace packages by putting dbt in the
mypy_path. Fix a number of exposed mypy/type checker complaints. The
checker mostly passes now even if you add `--check-untyped-defs`, though
there are a couple of lingering issues, so I'll leave that out of CI.

Change the return type of RunOperation a bit, adding a couple of fields
to appease mypy.

Also, bump the mypy version (it catches a few more issues).
---
 core/dbt/adapters/base/connections.py    |   5 +-
 core/dbt/adapters/base/impl.py           |  23 +-
 core/dbt/adapters/base/meta.py           |  22 +-
 core/dbt/adapters/base/query_headers.py  |  16 +-
 core/dbt/adapters/base/relation.py       |   9 +-
 core/dbt/adapters/cache.py               |  40 ++--
 core/dbt/adapters/factory.py             |   5 +-
 core/dbt/adapters/sql/impl.py            |   4 +-
 core/dbt/clients/_jinja_blocks.py        |   1 +
 core/dbt/clients/git.py                  |   4 +
 core/dbt/clients/jinja.py                |  49 ++--
 core/dbt/clients/system.py               |  79 +++---
 core/dbt/compilation.py                  |  17 +-
 core/dbt/config/profile.py               |   7 +-
 core/dbt/context/base.py                 |   6 +-
 core/dbt/context/common.py               |   4 +-
 core/dbt/context/runtime.py              |   2 +-
 core/dbt/contracts/connection.py         |  30 +--
 core/dbt/contracts/graph/compiled.py     |   3 +-
 core/dbt/contracts/graph/manifest.py     |  44 ++--
 core/dbt/contracts/graph/parsed.py       |  20 +-
 core/dbt/contracts/results.py            |  44 +++-
 core/dbt/contracts/rpc.py                |   9 +-
 core/dbt/deps/resolver.py                |   4 +-
 core/dbt/exceptions.py                   |   5 +-
 core/dbt/graph/selector.py               |  65 ++---
 core/dbt/helper_types.py                 |   8 -
 core/dbt/linker.py                       |   2 +-
 core/dbt/logger.py                       |   9 +-
 core/dbt/main.py                         |   8 +-
 core/dbt/node_runners.py                 |  24 +-
 core/dbt/parser/manifest.py              |   4 +-
 core/dbt/rpc/method.py                   |   5 +
 core/dbt/rpc/task_handler.py             |  27 ++-
 core/dbt/rpc/task_manager.py             |   6 -
 core/dbt/semver.py                       | 225 +++++++++---------
 core/dbt/source_config.py                |  10 +-
 core/dbt/task/debug.py                   |  20 +-
 core/dbt/task/generate.py                |  10 +
 core/dbt/task/list.py                    |  12 +-
 core/dbt/task/rpc/base.py                |   4 +-
 core/dbt/task/rpc/cli.py                 |  10 +-
 core/dbt/task/rpc/project_commands.py    |  29 ++-
 core/dbt/task/rpc/server.py              |   8 +-
 core/dbt/task/rpc/sql_commands.py        |  19 +-
 core/dbt/task/run.py                     |  29 ++-
 core/dbt/task/run_operation.py           |  36 ++-
 core/dbt/task/runnable.py                |  71 +++++-
 core/dbt/task/serve.py                   |   7 +-
 core/dbt/tracking.py                     |  29 ++-
 core/dbt/utils.py                        |  46 ++--
 dev_requirements.txt                     |   2 +-
 mypy.ini                                 |   2 +-
 .../macros/test_int_inference.sql        |  25 +-
 .../test_bigquery_query_results.py       |  22 +-
 test/unit/test_semver.py                 |   2 +-
 third-party-stubs/agate/__init__.pyi     |  16 +-
 third-party-stubs/jsonrpc/manager.pyi    |   4 +-
 .../snowplow_tracker/__init__.pyi        |   6 +
 59 files changed, 788 insertions(+), 466 deletions(-)

diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py
index ea2f39c97f7..7462c469dbd 100644
--- a/core/dbt/adapters/base/connections.py
+++ b/core/dbt/adapters/base/connections.py
@@ -1,6 +1,7 @@
 import abc
 import os
-from multiprocessing import RLock
+# multiprocessing.RLock is a function returning this type
+from multiprocessing.synchronize import RLock
 from threading import get_ident
 from typing import (
     Dict, Tuple, Hashable, Optional, ContextManager, List
@@ -144,7 +145,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
                 'Opening a new connection, currently in state {}'
                 .format(conn.state)
             )
-            conn.handle = LazyHandle(type(self))
+            conn.handle = LazyHandle(self.open)
         conn.name = conn_name
         return conn
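The LazyHandle change above pairs with the reworked LazyHandle class in
core/dbt/contracts/connection.py later in this patch: instead of holding a
class that exposes an open() classmethod, the handle now holds any callable
that takes and returns a Connection. A minimal standalone sketch of the
pattern, where Connection and open_connection are simplified stand-ins for
the real dbt types, not the actual implementation:

    from typing import Any, Callable, Optional

    class Connection:
        def __init__(self) -> None:
            self.handle: Optional[Any] = None
            self.state = 'init'

    class LazyHandle:
        """Opener is a callable that opens the connection and updates
        the handle on the Connection."""
        def __init__(self, opener: Callable[[Connection], Connection]):
            self.opener = opener

        def resolve(self, connection: Connection) -> Connection:
            return self.opener(connection)

    def open_connection(conn: Connection) -> Connection:
        # a real adapter would connect to the warehouse here
        conn.handle = object()
        conn.state = 'open'
        return conn

    conn = Connection()
    conn.handle = LazyHandle(open_connection)  # nothing opened yet
    conn = conn.handle.resolve(conn)           # opens on first real use

Because `self.open` is a bound method, passing it satisfies the
Callable[[Connection], Connection] signature directly, without the Protocol
class the old code needed.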
diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py
index cbc06311553..08258806f16 100644
--- a/core/dbt/adapters/base/impl.py
+++ b/core/dbt/adapters/base/impl.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 from typing import (
     Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
-    Mapping, Iterator, Union
+    Mapping, Iterator, Union, Set
 )

 import agate
@@ -125,7 +125,10 @@ def add(self, relation: BaseRelation):
         key = relation.information_schema_only()
         if key not in self:
             self[key] = set()
-        self[key].add(relation.schema.lower())
+        lowered: Optional[str] = None
+        if relation.schema is not None:
+            lowered = relation.schema.lower()
+        self[key].add(lowered)

     def search(self):
         for information_schema_name, schemas in self.items():
@@ -133,7 +136,7 @@ def search(self):
             yield information_schema_name, schema

     def schemas_searched(self):
-        result = set()
+        result: Set[Tuple[str, str]] = set()
         for information_schema_name, schemas in self.items():
             result.update(
                 (information_schema_name.database, schema)
@@ -907,13 +910,17 @@ def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:

     @available
     @classmethod
-    def convert_type(cls, agate_table, col_idx):
+    def convert_type(
+        cls, agate_table: agate.Table, col_idx: int
+    ) -> Optional[str]:
         return cls.convert_agate_type(agate_table, col_idx)

     @classmethod
-    def convert_agate_type(cls, agate_table, col_idx):
-        agate_type = agate_table.column_types[col_idx]
-        conversions = [
+    def convert_agate_type(
+        cls, agate_table: agate.Table, col_idx: int
+    ) -> Optional[str]:
+        agate_type: Type = agate_table.column_types[col_idx]
+        conversions: List[Tuple[Type, Callable[..., str]]] = [
             (agate.Text, cls.convert_text_type),
             (agate.Number, cls.convert_number_type),
             (agate.Boolean, cls.convert_boolean_type),
@@ -925,6 +932,8 @@ def convert_agate_type(cls, agate_table, col_idx):
             if isinstance(agate_type, agate_cls):
                 return func(agate_table, col_idx)

+        return None
+
     ###
     # Operations involving the manifest
     ###

diff --git a/core/dbt/adapters/base/meta.py b/core/dbt/adapters/base/meta.py
index 5858a919ab1..209240c0de7 100644
--- a/core/dbt/adapters/base/meta.py
+++ b/core/dbt/adapters/base/meta.py
@@ -1,6 +1,6 @@
 import abc
 from functools import wraps
-from typing import Callable, Optional, Any, FrozenSet, Dict
+from typing import Callable, Optional, Any, FrozenSet, Dict, Set

 from dbt.deprecations import warn, renamed_method
@@ -86,16 +86,26 @@ def parse_list(self, func: Callable) -> Callable:

 class AdapterMeta(abc.ABCMeta):
+    _available_: FrozenSet[str]
+    _parse_replacements_: Dict[str, Callable]
+
     def __new__(mcls, name, bases, namespace, **kwargs):
-        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
+        # mypy does not like the `**kwargs`. But `ABCMeta` itself takes
+        # `**kwargs` in its argspec here (and passes them to `type.__new__`).
+        # I'm not sure there is any benefit to it after poking around a bit,
+        # but having it doesn't hurt on the python side (and omitting it could
+        # hurt for obscure metaclass reasons, for all I know)
+        cls = abc.ABCMeta.__new__(  # type: ignore
+            mcls, name, bases, namespace, **kwargs
+        )

         # this is very much inspired by ABCMeta's own implementation

         # dict mapping the method name to whether the model name should be
         # injected into the arguments. All methods in here are exposed to the
         # context.
-        available = set()
-        replacements = {}
+        available: Set[str] = set()
+        replacements: Dict[str, Any] = {}

         # collect base class data first
         for base in bases:
@@ -110,7 +120,7 @@ def __new__(mcls, name, bases, namespace, **kwargs):
             if parse_replacement is not None:
                 replacements[name] = parse_replacement

-        cls._available_: FrozenSet[str] = frozenset(available)
+        cls._available_ = frozenset(available)
         # should this be a namedtuple so it will be immutable like _available_?
-        cls._parse_replacements_: Dict[str, Callable] = replacements
+        cls._parse_replacements_ = replacements
         return cls

diff --git a/core/dbt/adapters/base/query_headers.py b/core/dbt/adapters/base/query_headers.py
index 57bccbcd02c..b7a6c8fdd40 100644
--- a/core/dbt/adapters/base/query_headers.py
+++ b/core/dbt/adapters/base/query_headers.py
@@ -1,5 +1,5 @@
 from threading import local
-from typing import Optional, Callable
+from typing import Optional, Callable, Dict, Any

 from dbt.clients.jinja import QueryStringGenerator
@@ -56,7 +56,7 @@ def add(self, sql: str) -> str:
         return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

     def set(self, comment: Optional[str]):
-        if '*/' in comment:
+        if isinstance(comment, str) and '*/' in comment:
             # tell the user "no" so they don't hurt themselves by writing
             # garbage
             raise RuntimeException(
@@ -65,7 +65,7 @@ def set(self, comment: Optional[str]):
         self.query_comment = comment


-QueryStringFunc = Callable[[str, Optional[CompileResultNode]], str]
+QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str]


 class QueryStringSetter:
@@ -77,13 +77,14 @@ def __init__(self, config: AdapterRequiredConfig):
         self.generator: QueryStringFunc = lambda name, model: ''
         # if the comment value was None or the empty string, just skip it
         if comment_macro:
+            assert isinstance(comment_macro, str)
             macro = '\n'.join((
                 '{%- macro query_comment_macro(connection_name, node) -%}',
-                self._get_comment_macro(),
+                comment_macro,
                 '{% endmacro %}'
             ))
             ctx = self._get_context()
-            self.generator: QueryStringFunc = QueryStringGenerator(macro, ctx)
+            self.generator = QueryStringGenerator(macro, ctx)
         self.comment = _QueryComment(None)
         self.reset()
@@ -105,10 +106,9 @@ def reset(self):
         self.set('master', None)

     def set(self, name: str, node: Optional[CompileResultNode]):
+        wrapped: Optional[NodeWrapper] = None
         if node is not None:
             wrapped = NodeWrapper(node)
-        else:
-            wrapped = None
         comment_str = self.generator(name, wrapped)
         self.comment.set(comment_str)
@@ -127,5 +127,5 @@ def _get_comment_macro(self):
         else:
             return super()._get_comment_macro()

-    def _get_context(self):
+    def _get_context(self) -> Dict[str, Any]:
         return QueryHeaderContext(self.config).to_dict(self.manifest.macros)

diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py
index 957ac380274..611e1b8902a 100644
--- a/core/dbt/adapters/base/relation.py
+++ b/core/dbt/adapters/base/relation.py
@@ -422,10 +422,13 @@ def External(cls) -> str:
         return str(RelationType.External)

     @classproperty
-    def RelationType(cls) -> Type[RelationType]:
+    def get_relation_type(cls) -> Type[RelationType]:
         return RelationType


+Info = TypeVar('Info', bound='InformationSchema')
+
+
 @dataclass(frozen=True, eq=False, repr=False)
 class InformationSchema(BaseRelation):
     information_schema_view: Optional[str] = None
@@ -470,10 +473,10 @@ def get_quote_policy(

     @classmethod
     def from_relation(
-        cls: Self,
+        cls: Type[Info],
         relation: BaseRelation,
         information_schema_view: Optional[str],
-    ) -> Self:
+    ) -> Info:
         include_policy = cls.get_include_policy(
             relation,
             information_schema_view
         )

diff --git a/core/dbt/adapters/cache.py b/core/dbt/adapters/cache.py
index 973ecd8f76d..157684fe9c2 100644
--- a/core/dbt/adapters/cache.py
+++ b/core/dbt/adapters/cache.py
@@ -1,6 +1,6 @@
 from collections import namedtuple
 from copy import deepcopy
-from typing import List, Iterable, Optional
+from typing import List, Iterable, Optional, Dict, Set, Tuple, Any
 import threading

 from dbt.logger import CACHE_LOGGER as logger
@@ -177,12 +177,14 @@ class RelationsCache:
         The adapters also hold this lock while filling the cache.
     :attr Set[str] schemas: The set of known/cached schemas, all lowercased.
     """
-    def __init__(self):
-        self.relations = {}
+    def __init__(self) -> None:
+        self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
         self.lock = threading.RLock()
-        self.schemas = set()
+        self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()

-    def add_schema(self, database: str, schema: str):
+    def add_schema(
+        self, database: Optional[str], schema: Optional[str],
+    ) -> None:
         """Add a schema to the set of known schemas (case-insensitive)

         :param database: The database name to add.
@@ -190,7 +192,9 @@
         """
         self.schemas.add((_lower(database), _lower(schema)))

-    def drop_schema(self, database: str, schema: str):
+    def drop_schema(
+        self, database: Optional[str], schema: Optional[str],
+    ) -> None:
         """Drop the given schema and remove it from the set of known schemas.

         Then remove all its contents (and their dependents, etc) as well.
@@ -208,21 +212,21 @@
             # handle a drop_schema race by using discard() over remove()
             self.schemas.discard(key)

-    def update_schemas(self, schemas: Iterable[str]):
+    def update_schemas(self, schemas: Iterable[Tuple[Optional[str], str]]):
         """Add multiple schemas to the set of known schemas (case-insensitive)

         :param schemas: An iterable of the schema names to add.
         """
-        self.schemas.update((_lower(d), _lower(s)) for (d, s) in schemas)
+        self.schemas.update((_lower(d), s.lower()) for (d, s) in schemas)

-    def __contains__(self, schema_id):
+    def __contains__(self, schema_id: Tuple[Optional[str], str]):
         """A schema is 'in' the relations cache if it is in the set of cached
         schemas.

-        :param Tuple[str, str] schema: The db name and schema name to look up.
+        :param schema_id: The db name and schema name to look up.
         """
         db, schema = schema_id
-        return (_lower(db), _lower(schema)) in self.schemas
+        return (_lower(db), schema.lower()) in self.schemas

     def dump_graph(self):
         """Dump a key-only representation of the schema to a dictionary. Every
@@ -238,7 +242,7 @@ def dump_graph(self):
             for k, v in self.relations.items()
         }

-    def _setdefault(self, relation):
+    def _setdefault(self, relation: _CachedRelation):
         """Add a relation to the cache, or return it if it already exists.

         :param _CachedRelation relation: The relation to set or get.
@@ -275,6 +279,8 @@ def _add_link(self, referenced_key, dependent_key):
                 .format(dependent_key)
             )

+        assert dependent is not None  # we just raised!
+
         referenced.add_reference(dependent)

     def add_link(self, referenced, dependent):
@@ -305,7 +311,7 @@ def add_link(self, referenced, dependent):
         if ref_key not in self.relations:
             # Insert a dummy "external" relation.
             referenced = referenced.replace(
-                type=referenced.RelationType.External
+                type=referenced.External
             )
             self.add(referenced)

@@ -313,7 +319,7 @@
         if dep_key not in self.relations:
             # Insert a dummy "external" relation.
             dependent = dependent.replace(
-                type=referenced.RelationType.External
+                type=referenced.External
             )
             self.add(dependent)
         logger.debug(
@@ -469,7 +475,9 @@ def rename(self, old, new):

         lazy_log('after rename: {!s}', self.dump_graph)

-    def get_relations(self, database, schema):
+    def get_relations(
+        self, database: Optional[str], schema: Optional[str]
+    ) -> List[Any]:
         """Case-insensitively yield all relations matching the given schema.

         :param str schema: The case-insensitive schema name to list from.
@@ -498,7 +506,7 @@ def clear(self):
         self.schemas.clear()

     def _list_relations_in_schema(
-        self, database: str, schema: str
+        self, database: Optional[str], schema: Optional[str]
     ) -> List[_CachedRelation]:
         """Get the relations in a schema. Callers should hold the lock."""
         key = (_lower(database), _lower(schema))

diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py
index 0742a1e0c69..b9ccf714018 100644
--- a/core/dbt/adapters/factory.py
+++ b/core/dbt/adapters/factory.py
@@ -45,7 +45,8 @@ def load_plugin(self, name: str) -> Type[Credentials]:
         # and adapter_type entries with the same value, as they're all
         # singletons
         try:
-            mod = import_module('.' + name, 'dbt.adapters')
+            # mypy doesn't think modules have any attributes.
+            mod: Any = import_module('.' + name, 'dbt.adapters')
         except ModuleNotFoundError as exc:
             # if we failed to import the target module in particular, inform
             # the user about it via a runtime error
@@ -56,7 +57,7 @@ def load_plugin(self, name: str) -> Type[Credentials]:
             # library. Log the stack trace.
             logger.debug('', exc_info=True)
             raise
-        plugin = mod.Plugin  # type: AdapterPlugin
+        plugin: AdapterPlugin = mod.Plugin
         plugin_type = plugin.adapter.type()

         if plugin_type != name:

diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py
index b0542edf60a..cb79ff9348a 100644
--- a/core/dbt/adapters/sql/impl.py
+++ b/core/dbt/adapters/sql/impl.py
@@ -212,9 +212,9 @@ def list_relations_without_caching(
         }
         for _database, name, _schema, _type in results:
             try:
-                _type = self.Relation.RelationType(_type)
+                _type = self.Relation.get_relation_type(_type)
             except ValueError:
-                _type = self.Relation.RelationType.External
+                _type = self.Relation.External
             relations.append(self.Relation.create(
                 database=_database,
                 schema=_schema,

diff --git a/core/dbt/clients/_jinja_blocks.py b/core/dbt/clients/_jinja_blocks.py
index 436e8ccda22..5e92faf45ee 100644
--- a/core/dbt/clients/_jinja_blocks.py
+++ b/core/dbt/clients/_jinja_blocks.py
@@ -347,6 +347,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
             elif self.is_current_end(tag):
                 self.last_position = tag.end
+                assert self.current is not None
                 yield BlockTag(
                     block_type_name=self.current.block_type_name,
                     block_name=self.current.block_name,

diff --git a/core/dbt/clients/git.py b/core/dbt/clients/git.py
index 5aa0baa33a6..e0d98d59461 100644
--- a/core/dbt/clients/git.py
+++ b/core/dbt/clients/git.py
@@ -84,6 +84,10 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
         logger.debug('Updating existing dependency {}.', directory)
     else:
         matches = re.match("Cloning into '(.+)'", err.decode('utf-8'))
+        if matches is None:
+            raise dbt.exceptions.RuntimeException(
+                f'Error cloning {repo} - never saw "Cloning into ..." from git'
+            )
         directory = matches.group(1)
         logger.debug('Pulling new dependency {}.', directory)
     full_path = os.path.join(cwd, directory)

diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py
index bf3943b9bb4..85c19548d82 100644
--- a/core/dbt/clients/jinja.py
+++ b/core/dbt/clients/jinja.py
@@ -3,7 +3,9 @@
 import os
 import tempfile
 from contextlib import contextmanager
-from typing import List, Union, Set, Optional, Dict, Any, Callable, Iterator
+from typing import (
+    List, Union, Set, Optional, Dict, Any, Callable, Iterator, Type
+)

 import jinja2
 import jinja2._compat
@@ -35,17 +37,22 @@ def _linecache_inject(source, write):
             tmp_file.write(source)
         filename = tmp_file.name
     else:
-        filename = codecs.encode(os.urandom(12), 'hex').decode('ascii')
+        # `codecs.encode` actually takes a `bytes` as the first argument if
+        # the second argument is 'hex' - mypy does not know this.
+        rnd = codecs.encode(os.urandom(12), 'hex')  # type: ignore
+        filename = rnd.decode('ascii')

     # encode, though I don't think this matters
     filename = jinja2._compat.encode_filename(filename)
     # put ourselves in the cache
-    linecache.cache[filename] = (
+    cache_entry = (
         len(source),
         None,
         [line + '\n' for line in source.splitlines()],
         filename
     )
+    # linecache does in fact have an attribute `cache`, thanks
+    linecache.cache[filename] = cache_entry  # type: ignore
     return filename
@@ -84,7 +91,7 @@ def _compile(self, source, filename):
             write = MACRO_DEBUGGING == 'write'
             filename = _linecache_inject(source, write)

-        return super()._compile(source, filename)
+        return super()._compile(source, filename)  # type: ignore


 class TemplateCache:
@@ -317,7 +324,7 @@ def __call__(self, *args, **kwargs):


 def get_environment(node=None, capture_macros=False):
-    args = {
+    args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
         'extensions': ['jinja2.ext.do']
     }

@@ -330,38 +337,34 @@ def get_environment(node=None, capture_macros=False):
     return MacroFuzzEnvironment(**args)


-def parse(string):
+@contextmanager
+def catch_jinja(node=None) -> Iterator[None]:
     try:
-        return get_environment().parse(str(string))
-
-    except (jinja2.exceptions.TemplateSyntaxError,
-            jinja2.exceptions.UndefinedError) as e:
+        yield
+    except jinja2.exceptions.TemplateSyntaxError as e:
         e.translated = False
-        dbt.exceptions.raise_compiler_error(str(e))
+        raise dbt.exceptions.CompilationException(str(e), node) from e
+    except jinja2.exceptions.UndefinedError as e:
+        raise dbt.exceptions.CompilationException(str(e), node) from e
+
+
+def parse(string):
+    with catch_jinja():
+        return get_environment().parse(str(string))


 def get_template(string, ctx, node=None, capture_macros=False):
-    try:
+    with catch_jinja(node):
         env = get_environment(node, capture_macros)

         template_source = str(string)
         return env.from_string(template_source, globals=ctx)

-    except (jinja2.exceptions.TemplateSyntaxError,
-            jinja2.exceptions.UndefinedError) as e:
-        e.translated = False
-        dbt.exceptions.raise_compiler_error(str(e), node)
-

 def render_template(template, ctx, node=None):
-    try:
+    with catch_jinja(node):
         return template.render(ctx)

-    except (jinja2.exceptions.TemplateSyntaxError,
-            jinja2.exceptions.UndefinedError) as e:
-        e.translated = False
-        dbt.exceptions.raise_compiler_error(str(e), node)
-

 def get_rendered(string, ctx, node=None, capture_macros=False):
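The jinja.py change above replaces three copies of the same try/except with
one reusable context manager. A minimal standalone sketch of the idea, where
CompilationException is a simplified stand-in for
dbt.exceptions.CompilationException rather than the real class:

    from contextlib import contextmanager
    from typing import Iterator, Optional

    import jinja2

    class CompilationException(Exception):
        def __init__(self, msg: str, node: Optional[object] = None) -> None:
            super().__init__(msg)
            self.node = node

    @contextmanager
    def catch_jinja(node: Optional[object] = None) -> Iterator[None]:
        # one place to translate jinja errors into dbt errors
        try:
            yield
        except jinja2.exceptions.TemplateSyntaxError as e:
            e.translated = False
            raise CompilationException(str(e), node) from e
        except jinja2.exceptions.UndefinedError as e:
            raise CompilationException(str(e), node) from e

    def parse(source: str):
        with catch_jinja():
            return jinja2.Environment().parse(source)

Raising a typed exception with `from e` (instead of calling a helper that
raises) also lets mypy see that the except branches never fall through.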
diff --git a/core/dbt/clients/system.py b/core/dbt/clients/system.py
index c642ac1600e..713e4d4d40d 100644
--- a/core/dbt/clients/system.py
+++ b/core/dbt/clients/system.py
@@ -9,6 +9,9 @@
 import tarfile
 import requests
 import stat
+from typing import (
+    Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable
+)

 import dbt.exceptions
 import dbt.utils
@@ -16,9 +19,11 @@
 from dbt.logger import GLOBAL_LOGGER as logger


-def find_matching(root_path,
-                  relative_paths_to_search,
-                  file_pattern):
+def find_matching(
+    root_path: str,
+    relative_paths_to_search: List[str],
+    file_pattern: str,
+) -> List[Dict[str, str]]:
     """
     Given an absolute `root_path`, a list of relative paths to that
     absolute root path (`relative_paths_to_search`), and a `file_pattern`
@@ -58,7 +63,7 @@
     return matching


-def load_file_contents(path, strip=True):
+def load_file_contents(path: str, strip: bool = True) -> str:
     with open(path, 'rb') as handle:
         to_return = handle.read().decode('utf-8')
@@ -68,7 +73,7 @@
     return to_return


-def make_directory(path):
+def make_directory(path: str) -> None:
     """
     Make a directory and any intermediate directories that don't already
     exist. This function handles the case where two threads try to create
@@ -86,7 +91,7 @@
             raise e


-def make_file(path, contents='', overwrite=False):
+def make_file(path: str, contents: str = '', overwrite: bool = False) -> bool:
     """
     Make a file at `path` assuming that the directory it resides in already
     exists. The file is saved with contents `contents`
@@ -99,21 +104,21 @@
     return False


-def make_symlink(source, link_path):
+def make_symlink(source: str, link_path: str) -> None:
     """
     Create a symlink at `link_path` referring to `source`.
     """
     if not supports_symlinks():
         dbt.exceptions.system_error('create a symbolic link')

-    return os.symlink(source, link_path)
+    os.symlink(source, link_path)


-def supports_symlinks():
+def supports_symlinks() -> bool:
     return getattr(os, "symlink", None) is not None


-def write_file(path, contents=''):
+def write_file(path: str, contents: str = '') -> bool:
     make_directory(os.path.dirname(path))
     with open(path, 'w', encoding='utf-8') as f:
         f.write(str(contents))
@@ -121,11 +126,13 @@
     return True


-def write_json(path, data):
+def write_json(path: str, data: Dict[str, Any]) -> bool:
     return write_file(path, json.dumps(data, cls=dbt.utils.JSONEncoder))


-def _windows_rmdir_readonly(func, path, exc):
+def _windows_rmdir_readonly(
+    func: Callable[[str], Any], path: str, exc: Tuple[Any, OSError, Any]
+):
     exception_val = exc[1]
     if exception_val.errno == errno.EACCES:
         os.chmod(path, stat.S_IWUSR)
@@ -134,7 +141,7 @@
         raise


-def resolve_path_from_base(path_to_resolve, base_path):
+def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
     """
     If path-to_resolve is a relative path, create an absolute path
     with base_path as the base.
@@ -148,7 +155,7 @@
             os.path.expanduser(path_to_resolve)))


-def rmdir(path):
+def rmdir(path: str) -> None:
     """
     Recursively deletes a directory. Includes an error handler to retry with
     different permissions on Windows. Otherwise, removing directories (eg.
@@ -160,22 +167,22 @@
     else:
         onerror = None

-    return shutil.rmtree(path, onerror=onerror)
+    shutil.rmtree(path, onerror=onerror)


-def remove_file(path):
-    return os.remove(path)
+def remove_file(path: str) -> None:
+    os.remove(path)


-def path_exists(path):
+def path_exists(path: str) -> bool:
     return os.path.lexists(path)


-def path_is_symlink(path):
+def path_is_symlink(path: str) -> bool:
     return os.path.islink(path)


-def open_dir_cmd():
+def open_dir_cmd() -> str:
     # https://docs.python.org/2/library/sys.html#sys.platform
     if sys.platform == 'win32':
         return 'start'
@@ -187,7 +194,9 @@
     return 'xdg-open'


-def _handle_posix_cwd_error(exc, cwd, cmd):
+def _handle_posix_cwd_error(
+    exc: OSError, cwd: str, cmd: List[str]
+) -> NoReturn:
     if exc.errno == errno.ENOENT:
         message = 'Directory does not exist'
     elif exc.errno == errno.EACCES:
@@ -199,7 +208,9 @@
     raise dbt.exceptions.WorkingDirectoryError(cwd, cmd, message)


-def _handle_posix_cmd_error(exc, cwd, cmd):
+def _handle_posix_cmd_error(
+    exc: OSError, cwd: str, cmd: List[str]
+) -> NoReturn:
     if exc.errno == errno.ENOENT:
         message = "Could not find command, ensure it is in the user's PATH"
     elif exc.errno == errno.EACCES:
@@ -209,7 +220,7 @@
     raise dbt.exceptions.ExecutableError(cwd, cmd, message)


-def _handle_posix_error(exc, cwd, cmd):
+def _handle_posix_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
     """OSError handling for posix systems.

     Some things that could happen to trigger an OSError:
@@ -236,8 +247,8 @@
     _handle_posix_cmd_error(exc, cwd, cmd)


-def _handle_windows_error(exc, cwd, cmd):
-    cls = dbt.exceptions.CommandError
+def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
+    cls: Type[dbt.exceptions.Exception] = dbt.exceptions.CommandError
     if exc.errno == errno.ENOENT:
         message = ("Could not find command, ensure it is in the user's PATH "
                    "and that the user has permissions to run it")
@@ -256,7 +267,7 @@
     raise cls(cwd, cmd, message)


-def _interpret_oserror(exc, cwd, cmd):
+def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
     """Interpret an OSError exc and raise the appropriate dbt exception.
""" @@ -275,7 +286,9 @@ def _interpret_oserror(exc, cwd, cmd): ) -def run_cmd(cwd, cmd, env=None): +def run_cmd( + cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None +) -> Tuple[bytes, bytes]: logger.debug('Executing "{}"'.format(' '.join(cmd))) if len(cmd) == 0: raise dbt.exceptions.CommandError(cwd, cmd) @@ -299,8 +312,8 @@ def run_cmd(cwd, cmd, env=None): except OSError as exc: _interpret_oserror(exc, cwd, cmd) - logger.debug('STDOUT: "{}"'.format(out)) - logger.debug('STDERR: "{}"'.format(err)) + logger.debug('STDOUT: "{!s}"'.format(out)) + logger.debug('STDERR: "{!s}"'.format(err)) if proc.returncode != 0: logger.debug('command return code={}'.format(proc.returncode)) @@ -310,14 +323,14 @@ def run_cmd(cwd, cmd, env=None): return out, err -def download(url, path): +def download(url: str, path: str) -> None: response = requests.get(url) with open(path, 'wb') as handle: for block in response.iter_content(1024 * 64): handle.write(block) -def rename(from_path, to_path, force=False): +def rename(from_path: str, to_path: str, force: bool = False) -> None: is_symlink = path_is_symlink(to_path) if os.path.exists(to_path) and force: @@ -329,7 +342,9 @@ def rename(from_path, to_path, force=False): shutil.move(from_path, to_path) -def untar_package(tar_path, dest_dir, rename_to=None): +def untar_package( + tar_path: str, dest_dir: str, rename_to: Optional[str] = None +) -> None: tar_dir_name = None with tarfile.open(tar_path, 'r') as tarball: tarball.extractall(dest_dir) diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 1aeee4a7774..34a412337e9 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -1,6 +1,7 @@ import itertools import os from collections import defaultdict +from typing import List, Dict import dbt.utils import dbt.include @@ -63,7 +64,7 @@ def _node_enabled(node): def _generate_stats(manifest): - stats = defaultdict(int) + stats: Dict[NodeType, int] = defaultdict(int) for node_name, node in itertools.chain( manifest.nodes.items(), manifest.macros.items()): @@ -102,7 +103,7 @@ def recursively_prepend_ctes(model, manifest): 'Bad model type: {}'.format(type(model)) ) - prepended_ctes = [] + prepended_ctes: List[InjectedCTE] = [] for cte in model.extra_ctes: cte_id = cte.id @@ -227,7 +228,7 @@ def compile(self, manifest, write=True): return linker -def compile_manifest(config, manifest, write=True): +def compile_manifest(config, manifest, write=True) -> Linker: compiler = Compiler(config) compiler.initialize() return compiler.compile(manifest, write=write) @@ -273,7 +274,9 @@ def _inject_runtime_config(adapter, node, extra_context): def _node_context(adapter, node): - return { - "run_started_at": dbt.tracking.active_user.run_started_at, - "invocation_id": dbt.tracking.active_user.invocation_id, - } + if dbt.tracking.active_user is not None: + return { + "run_started_at": dbt.tracking.active_user.run_started_at, + "invocation_id": dbt.tracking.active_user.invocation_id, + } + return {} # this never happens, but make mypy happy diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py index 81a04ba171b..7d2446b3f14 100644 --- a/core/dbt/config/profile.py +++ b/core/dbt/config/profile.py @@ -62,17 +62,14 @@ def read_profile(profiles_dir: str) -> Dict[str, Any]: def read_user_config(directory: str) -> UserConfig: try: - user_cfg = None profile = read_profile(directory) if profile: user_cfg = coerce_dict_str(profile.get('config', {})) if user_cfg is not None: return UserConfig.from_dict(user_cfg) - return UserConfig() - - return 
diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py
index 81a04ba171b..7d2446b3f14 100644
--- a/core/dbt/config/profile.py
+++ b/core/dbt/config/profile.py
@@ -62,17 +62,14 @@ def read_profile(profiles_dir: str) -> Dict[str, Any]:

 def read_user_config(directory: str) -> UserConfig:
     try:
-        user_cfg = None
         profile = read_profile(directory)
         if profile:
             user_cfg = coerce_dict_str(profile.get('config', {}))
             if user_cfg is not None:
                 return UserConfig.from_dict(user_cfg)
-            return UserConfig()
-
-        return UserConfig.from_dict(user_cfg)
     except (RuntimeException, ValidationError):
-        return UserConfig()
+        pass
+    return UserConfig()


 @dataclass

diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py
index 34761014412..8e16c39ac65 100644
--- a/core/dbt/context/base.py
+++ b/core/dbt/context/base.py
@@ -1,7 +1,7 @@
 import itertools
 import json
 import os
-from typing import Callable, Any, Dict, List, Optional
+from typing import Callable, Any, Dict, List, Optional, Mapping

 import dbt.tracking
 from dbt.clients.jinja import undefined_error
@@ -248,7 +248,7 @@ def search_package_name(self):
     def add_macros_from(
         self,
         context: Dict[str, Any],
-        macros: Dict[str, ParsedMacro]
+        macros: Mapping[str, ParsedMacro],
     ):
         global_macros: List[Dict[str, Callable]] = []
         local_macros: List[Dict[str, Callable]] = []
@@ -284,7 +284,7 @@ class QueryHeaderContext(HasCredentialsContext):
     def __init__(self, config):
         super().__init__(config)

-    def to_dict(self, macros: Optional[Dict[str, ParsedMacro]] = None):
+    def to_dict(self, macros: Optional[Mapping[str, ParsedMacro]] = None):
         context = super().to_dict()
         context['target'] = self.get_target()
         context['dbt_version'] = dbt_version

diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py
index 9b4a743a8f0..c029b91eea6 100644
--- a/core/dbt/context/common.py
+++ b/core/dbt/context/common.py
@@ -145,8 +145,8 @@ def inner(value: T) -> None:
     })


-def add_sql_handlers(context):
-    sql_results = {}
+def add_sql_handlers(context: Dict[str, Any]) -> None:
+    sql_results: Dict[str, Any] = {}
     context['_sql_results'] = sql_results
     context['store_result'] = _store_result(sql_results)
     context['load_result'] = _load_result(sql_results)

diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py
index e7c42c73ce4..cac013a13f2 100644
--- a/core/dbt/context/runtime.py
+++ b/core/dbt/context/runtime.py
@@ -90,7 +90,7 @@ def __init__(self, model, source_config=None):
         self.model = model
         # we never use or get a source config, only the parser cares

-    def __call__(*args, **kwargs):
+    def __call__(self, *args, **kwargs):
         return ''

     def set(self, name, value):

diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py
index 4628c5b77a1..4dad8a97dcf 100644
--- a/core/dbt/contracts/connection.py
+++ b/core/dbt/contracts/connection.py
@@ -2,7 +2,7 @@
 import itertools
 from dataclasses import dataclass, field
 from typing import (
-    Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Type
+    Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable
 )

 from typing_extensions import Protocol
@@ -27,23 +27,6 @@ class ConnectionState(StrEnum):
     FAIL = 'fail'


-class ConnectionOpenerProtocol(Protocol):
-    @classmethod
-    def open(cls, connection: 'Connection') -> Any:
-        raise NotImplementedError(f'open() not implemented for {cls.__name__}')
-
-
-class LazyHandle:
-    """Opener must be a callable that takes a Connection object and opens the
-    connection, updating the handle on the Connection.
-    """
-    def __init__(self, opener: Type[ConnectionOpenerProtocol]):
-        self.opener = opener
-
-    def resolve(self, connection: 'Connection') -> Any:
-        return self.opener.open(connection)
-
-
 @dataclass(init=False)
 class Connection(ExtensibleJsonSchemaMixin, Replaceable):
     type: Identifier
@@ -96,6 +79,17 @@ def handle(self, value):
         self._handle = value


+class LazyHandle:
+    """Opener must be a callable that takes a Connection object and opens the
+    connection, updating the handle on the Connection.
+ """ + def __init__(self, opener: Callable[[Connection], Connection]): + self.opener = opener + + def resolve(self, connection: Connection) -> Connection: + return self.opener(connection) + + # see https://github.com/python/mypy/issues/4717#issuecomment-373932080 # and https://github.com/python/mypy/issues/5374 # for why we have type: ignore. Maybe someday dataclasses + abstract classes diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index 2fada422c35..a5f9637a66c 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -3,6 +3,7 @@ ParsedAnalysisNode, ParsedModelNode, ParsedHookNode, + ParsedResource, ParsedRPCNode, ParsedSeedNode, ParsedSnapshotNode, @@ -200,7 +201,7 @@ def compiled_type_for(parsed: ParsedNode): return type(parsed) -def parsed_instance_for(compiled: CompiledNode) -> ParsedNode: +def parsed_instance_for(compiled: CompiledNode) -> ParsedResource: cls = PARSED_TYPES.get(compiled.resource_type) if cls is None: # how??? diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 5ccd3a8514d..1ee5ff26f21 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -3,13 +3,17 @@ import os from dataclasses import dataclass, field from datetime import datetime -from typing import Dict, List, Optional, Union, Mapping, Any +from typing import ( + Dict, List, Optional, Union, Mapping, MutableMapping, Any, Set, Tuple +) from uuid import UUID from hologram import JsonSchemaMixin -from dbt.contracts.graph.parsed import ParsedNode, ParsedMacro, \ - ParsedDocumentation +from dbt.contracts.graph.parsed import ( + ParsedNode, ParsedMacro, ParsedDocumentation, ParsedNodePatch, + ParsedSourceDefinition +) from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.util import Writable, Replaceable from dbt.exceptions import ( @@ -200,9 +204,9 @@ def build_edges(nodes): and return them as two separate dictionaries, each mapping unique IDs to lists of edges. """ - backward_edges = {} + backward_edges: Dict[str, List[str]] = {} # pre-populate the forward edge dict for simplicity - forward_edges = {node.unique_id: [] for node in nodes} + forward_edges: Dict[str, List[str]] = {n.unique_id: [] for n in nodes} for node in nodes: backward_edges[node.unique_id] = node.depends_on_nodes[:] for unique_id in node.depends_on_nodes: @@ -262,17 +266,21 @@ def __lt__(self, other: 'MaterializationCandidate') -> bool: class Manifest: """The manifest for the full graph, after parsing and during compilation. """ - nodes: Mapping[str, CompileResultNode] - macros: Mapping[str, ParsedMacro] - docs: Mapping[str, ParsedDocumentation] + nodes: MutableMapping[str, CompileResultNode] + macros: MutableMapping[str, ParsedMacro] + docs: MutableMapping[str, ParsedDocumentation] generated_at: datetime disabled: List[ParsedNode] - files: Mapping[str, SourceFile] + files: MutableMapping[str, SourceFile] metadata: ManifestMetadata = field(default_factory=ManifestMetadata) flat_graph: Dict[str, Any] = field(default_factory=dict) @classmethod - def from_macros(cls, macros=None, files=None) -> 'Manifest': + def from_macros( + cls, + macros: Optional[MutableMapping[str, ParsedMacro]] = None, + files: Optional[MutableMapping[str, SourceFile]] = None, + ) -> 'Manifest': if macros is None: macros = {} if files is None: @@ -322,6 +330,10 @@ def _find_by_name(self, name, package, subgraph, nodetype): None, all pacakges will be searched. 
        nodetype should be a list of NodeTypes to accept.
        """
+        search: Union[
+            MutableMapping[str, ParsedMacro],
+            MutableMapping[str, CompileResultNode],
+        ]
        if subgraph == 'nodes':
            search = self.nodes
        elif subgraph == 'macros':
            search = self.macros
@@ -411,8 +423,8 @@
         candidates.sort()
         return candidates[-1].macro

-    def get_resource_fqns(self):
-        resource_fqns = {}
+    def get_resource_fqns(self) -> Dict[str, Set[Tuple[str, ...]]]:
+        resource_fqns: Dict[str, Set[Tuple[str, ...]]] = {}
         for unique_id, node in self.nodes.items():
             if node.resource_type == NodeType.Source:
                 continue  # sources have no FQNs and can't be configured
@@ -430,7 +442,7 @@ def add_nodes(self, new_nodes):
                 raise_duplicate_resource_name(node, self.nodes[unique_id])
             self.nodes[unique_id] = node

-    def patch_nodes(self, patches):
+    def patch_nodes(self, patches: MutableMapping[str, ParsedNodePatch]):
         """Patch nodes with the given dict of patches. Note that this consumes
         the input!
         This relies on the fact that all nodes have unique _name_ fields, not
@@ -443,12 +455,13 @@
         for node in self.nodes.values():
             if node.resource_type == NodeType.Source:
                 continue
+            # we know this because of the check above
+            assert not isinstance(node, ParsedSourceDefinition)
             patch = patches.pop(node.name, None)
             if not patch:
                 continue
+
             expected_key = node.resource_type.pluralize()
-            if expected_key == patch.yaml_key:
-                node.patch(patch)
             if expected_key != patch.yaml_key:
                 if patch.yaml_key == 'models':
                     deprecations.warn(
@@ -469,6 +482,7 @@
                     '''
                     )
                     raise_compiler_error(msg)
+            node.patch(patch)

         # log debug-level warning about nodes we couldn't find

diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py
index a92db585108..3ba7fefb007 100644
--- a/core/dbt/contracts/graph/parsed.py
+++ b/core/dbt/contracts/graph/parsed.py
@@ -105,7 +105,7 @@ def __delitem__(self, key):
         del self._extra[key]

     def __iter__(self):
-        for fld in self._get_fields():
+        for fld, _ in self._get_fields():
             yield fld.name

         for key in self._extra:
@@ -157,7 +157,11 @@ class HasRelationMetadata(JsonSchemaMixin, Replaceable):
     schema: str


-class ParsedNodeMixins:
+class ParsedNodeMixins(JsonSchemaMixin):
+    resource_type: NodeType
+    depends_on: DependsOn
+    config: NodeConfig
+
     @property
     def is_refable(self):
         return self.resource_type in NodeType.refable()
@@ -431,11 +435,11 @@ class ParsedSnapshotNode(ParsedNode):
     ]

     @classmethod
-    def json_schema(cls, embeddable=False):
+    def json_schema(cls, embeddable: bool = False) -> Dict[str, Any]:
         schema = super().json_schema(embeddable)

         # mess with config
-        configs = [
+        configs: List[Tuple[str, Type[JsonSchemaMixin]]] = [
             (str(SnapshotStrategy.Check), CheckSnapshotConfig),
             (str(SnapshotStrategy.Timestamp), TimestampSnapshotConfig),
         ]
@@ -516,6 +520,14 @@ class ParsedSourceDefinition(
     source_meta: Dict[str, Any] = field(default_factory=dict)
     tags: List[str] = field(default_factory=list)

+    @property
+    def is_refable(self):
+        return False
+
+    @property
+    def is_ephemeral(self):
+        return False
+
     @property
     def is_ephemeral_model(self):
         return False

diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py
index b3c532e5a6c..69226debd8a 100644
--- a/core/dbt/contracts/results.py
+++ b/core/dbt/contracts/results.py
@@ -1,7 +1,10 @@
 from dbt.contracts.graph.manifest import CompileResultNode
-from dbt.contracts.graph.unparsed import Time, FreshnessStatus
+from dbt.contracts.graph.unparsed import (
+    Time, FreshnessStatus, FreshnessThreshold
+)
 from dbt.contracts.graph.parsed import ParsedSourceDefinition
 from dbt.contracts.util import Writable, Replaceable
+from dbt.exceptions import InternalException
 from dbt.logger import (
     TimingProcessor,
     JsonOnly,
@@ -15,7 +18,6 @@
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Union, Dict, List, Optional, Any, NamedTuple
-from numbers import Real


 @dataclass
@@ -51,7 +53,7 @@ class PartialResult(JsonSchemaMixin, Writable):
     error: Optional[str] = None
     status: Union[None, str, int, bool] = None
     execution_time: Union[str, int] = 0
-    thread_id: Optional[int] = 0
+    thread_id: Optional[str] = None
     timing: List[TimingInfo] = field(default_factory=list)
     fail: Optional[bool] = None
     warn: Optional[bool] = None
@@ -82,7 +84,7 @@ def to_dict(self, *args, **kwargs):
 class ExecutionResult(JsonSchemaMixin, Writable):
     results: List[Union[WritableRunModelResult, PartialResult]]
     generated_at: datetime
-    elapsed_time: Real
+    elapsed_time: float

     def __len__(self):
         return len(self.results)
@@ -94,6 +96,11 @@ def __getitem__(self, idx):
         return self.results[idx]


+@dataclass
+class RunOperationResult(ExecutionResult):
+    success: bool
+
+
 # due to issues with typing.Union collapsing subclasses, this can't subclass
 # PartialResult
 @dataclass
@@ -101,11 +108,11 @@ class SourceFreshnessResult(JsonSchemaMixin, Writable):
     node: ParsedSourceDefinition
     max_loaded_at: datetime
     snapshotted_at: datetime
-    age: Real
+    age: float
     status: FreshnessStatus
     error: Optional[str] = None
     execution_time: Union[str, int] = 0
-    thread_id: Optional[int] = 0
+    thread_id: Optional[str] = None
     timing: List[TimingInfo] = field(default_factory=list)
     fail: Optional[bool] = None
@@ -124,7 +131,7 @@ def skipped(self):
 @dataclass
 class FreshnessMetadata(JsonSchemaMixin):
     generated_at: datetime
-    elapsed_time: Real
+    elapsed_time: float


 @dataclass
@@ -139,6 +146,9 @@ def write(self, path, omit_none=True):
         )
         sources = {}
         for result in self.results:
+            result_value: Union[
+                SourceFreshnessRuntimeError, SourceFreshnessOutput
+            ]
             unique_id = result.node.unique_id
             if result.error is not None:
                 result_value = SourceFreshnessRuntimeError(
@@ -146,12 +156,26 @@
                     state=FreshnessErrorEnum.runtime_error,
                 )
             else:
+                # we know that this must be a SourceFreshnessResult
+                if not isinstance(result, SourceFreshnessResult):
+                    raise InternalException(
+                        'Got {} instead of a SourceFreshnessResult for a '
+                        'non-error result in freshness execution!'
+                        .format(type(result))
+                    )
+                # if we're here, we must have a non-None freshness threshold
+                criteria = result.node.freshness
+                if criteria is None:
+                    raise InternalException(
+                        'Somehow evaluated a freshness result for a source '
+                        'that has no freshness criteria!'
+                    )
                 result_value = SourceFreshnessOutput(
                     max_loaded_at=result.max_loaded_at,
                     snapshotted_at=result.snapshotted_at,
                     max_loaded_at_time_ago_in_s=result.age,
                     state=result.status,
-                    criteria=result.node.freshness,
+                    criteria=criteria,
                 )
             sources[unique_id] = result_value
         output = FreshnessRunOutput(meta=meta, sources=sources)
@@ -191,9 +215,9 @@ class SourceFreshnessRuntimeError(JsonSchemaMixin):
 class SourceFreshnessOutput(JsonSchemaMixin):
     max_loaded_at: datetime
     snapshotted_at: datetime
-    max_loaded_at_time_ago_in_s: Real
+    max_loaded_at_time_ago_in_s: float
     state: FreshnessStatus
-    criteria: FreshnessCriteria
+    criteria: FreshnessThreshold


 SourceFreshnessRunResult = Union[SourceFreshnessOutput,

diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py
index a1dc474a284..23b152f5d34 100644
--- a/core/dbt/contracts/rpc.py
+++ b/core/dbt/contracts/rpc.py
@@ -3,7 +3,6 @@
 import uuid
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
-from numbers import Real
 from typing import Optional, Union, List, Any, Dict, Type

 from hologram import JsonSchemaMixin
@@ -28,7 +27,7 @@

 @dataclass
 class RPCParameters(JsonSchemaMixin):
-    timeout: Optional[Real]
+    timeout: Optional[float]
     task_tags: TaskTags
@@ -185,7 +184,7 @@ class ResultTable(JsonSchemaMixin):


 @dataclass
-class RemoteRunOperationResult(RemoteResult):
+class RemoteRunOperationResult(ExecutionResult, RemoteResult):
     success: bool
@@ -199,6 +198,7 @@ class RemoteRunResult(RemoteCompileResult):
     RemoteExecutionResult,
     RemoteCatalogResults,
     RemoteEmptyResult,
+    RemoteRunOperationResult,
 ]
@@ -480,6 +480,9 @@ def from_result(
     ) -> 'PollRunOperationCompleteResult':
         return cls(
             success=base.success,
+            results=base.results,
+            generated_at=base.generated_at,
+            elapsed_time=base.elapsed_time,
             logs=logs,
             tags=tags,
             state=timing.state,

diff --git a/core/dbt/deps/resolver.py b/core/dbt/deps/resolver.py
index 8237463a48c..b499f407f18 100644
--- a/core/dbt/deps/resolver.py
+++ b/core/dbt/deps/resolver.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, NoReturn, Union, Type, Iterator
+from typing import Dict, List, NoReturn, Union, Type, Iterator, Set

 from dbt.exceptions import raise_dependency_error, InternalException
@@ -98,7 +98,7 @@ def __iter__(self) -> Iterator[UnpinnedPackage]:
 def _check_for_duplicate_project_names(
     final_deps: List[PinnedPackage], config: Project
 ):
-    seen = set()
+    seen: Set[str] = set()
     for package in final_deps:
         project_name = package.get_project_name(config)
         if project_name in seen:

diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py
index 40aceccbb98..a2ace5536f8 100644
--- a/core/dbt/exceptions.py
+++ b/core/dbt/exceptions.py
@@ -801,11 +801,10 @@ def wrap(func):
         def inner(*args, **kwargs):
             try:
                 return func(*args, **kwargs)
-            except Exception as exc:
-                if hasattr(exc, 'node') and exc.node is None:
+            except RuntimeException as exc:
+                if exc.node is None:
                     exc.node = model
                 raise exc
-
         return inner
     return wrap

diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py
index 595b0d122bf..147b6a2ef65 100644
--- a/core/dbt/graph/selector.py
+++ b/core/dbt/graph/selector.py
@@ -1,5 +1,6 @@
 from enum import Enum
 from itertools import chain
+from typing import Set, Iterable, Union, List, Container, Tuple, Optional

 import networkx as nx  # type: ignore
@@ -16,7 +17,7 @@

 class SelectionCriteria:
-    def __init__(self, node_spec):
+    def __init__(self, node_spec: str):
         self.raw = node_spec
         self.select_children = False
         self.select_parents = False
@@ -43,7 +44,8 @@ def __init__(self, node_spec):

         if SELECTOR_DELIMITER in node_spec:
             selector_parts = node_spec.split(SELECTOR_DELIMITER, 1)
-            self.selector_type, self.selector_value = selector_parts
+            selector_type, self.selector_value = selector_parts
+            self.selector_type = SELECTOR_FILTERS(selector_type)
         else:
             self.selector_value = node_spec
@@ -57,8 +59,8 @@ def __str__(self):
         return self._value_


-def split_specs(node_specs):
-    specs = set()
+def split_specs(node_specs: Iterable[str]):
+    specs: Set[str] = set()
     for spec in node_specs:
         parts = spec.split(" ")
         specs.update(parts)
@@ -96,7 +98,9 @@ def is_selected_node(real_node, node_selector):
     return True


-def _node_is_match(qualified_name, package_names, fqn):
+def _node_is_match(
+    qualified_name: List[str], package_names: Set[str], fqn: List[str]
+) -> bool:
     """Determine if a qualified name matches an fqn, given the set of package
     names in the graph.
@@ -120,22 +124,18 @@
     return False


-def warn_if_useless_spec(spec, nodes):
-    if len(nodes) > 0:
-        return
-
-    msg = (
-        "* Spec='{}' does not identify any models"
-        .format(spec['raw'])
-    )
-    dbt.exceptions.warn_or_error(msg, log_fmt='{} and was ignored\n')
-
-
 class ManifestSelector:
+    FILTER: str
+
     def __init__(self, manifest):
         self.manifest = manifest

-    def _node_iterator(self, included_nodes, exclude, include):
+    def _node_iterator(
+        self,
+        included_nodes: Set[str],
+        exclude: Optional[Container[str]],
+        include: Optional[Container[str]],
+    ) -> Iterable[Tuple[str, str]]:
         for unique_id, node in self.manifest.nodes.items():
             if unique_id not in included_nodes:
                 continue
@@ -223,6 +223,9 @@ class InvalidSelectorError(Exception):
     pass


+ValidSelector = Union[QualifiedNameSelector, TagSelector, SourceSelector]
+
+
 class MultiSelector:
     """The base class of the node selector. It only knows about the manifest
     and selector types, including the glob operator, but does not handle any
     graph-related behavior.
@@ -233,7 +236,9 @@ class MultiSelector:
     def __init__(self, manifest):
         self.manifest = manifest

-    def get_selector(self, selector_type):
+    def get_selector(
+        self, selector_type: str
+    ):
         for cls in self.SELECTORS:
             if cls.FILTER == selector_type:
                 return cls(self.manifest)
@@ -258,30 +263,32 @@ def nodes(self):
     def __iter__(self):
         return iter(self.graph.nodes())

-    def select_childrens_parents(self, selected):
+    def select_childrens_parents(self, selected: Set[str]) -> Set[str]:
         ancestors_for = self.select_children(selected) | selected
         return self.select_parents(ancestors_for) | ancestors_for

-    def select_children(self, selected):
-        descendants = set()
+    def select_children(self, selected: Set[str]) -> Set[str]:
+        descendants: Set[str] = set()
         for node in selected:
             descendants.update(nx.descendants(self.graph, node))
         return descendants

-    def select_parents(self, selected):
-        ancestors = set()
+    def select_parents(self, selected: Set[str]) -> Set[str]:
+        ancestors: Set[str] = set()
         for node in selected:
             ancestors.update(nx.ancestors(self.graph, node))
         return ancestors

-    def select_successors(self, selected):
-        successors = set()
+    def select_successors(self, selected: Set[str]) -> Set[str]:
+        successors: Set[str] = set()
         for node in selected:
             successors.update(self.graph.successors(node))
         return successors

-    def collect_models(self, selected, spec):
-        additional = set()
+    def collect_models(
+        self, selected: Set[str], spec: SelectionCriteria,
+    ) -> Set[str]:
+        additional: Set[str] = set()
         if spec.select_childrens_parents:
             additional.update(self.select_childrens_parents(selected))
         if spec.select_parents:
@@ -290,7 +297,7 @@
             additional.update(self.select_children(selected))
         return additional

-    def subgraph(self, nodes):
+    def subgraph(self, nodes: Iterable[str]) -> 'Graph':
         cls = type(self)
         return cls(self.graph.subgraph(nodes))
@@ -326,7 +333,7 @@ def get_nodes_from_spec(self, graph, spec):
         return collected

     def select_nodes(self, graph, raw_include_specs, raw_exclude_specs):
-        selected_nodes = set()
+        selected_nodes: Set[str] = set()

         for raw_spec in split_specs(raw_include_specs):
             spec = SelectionCriteria(raw_spec)

diff --git a/core/dbt/helper_types.py b/core/dbt/helper_types.py
index 7484acb1f6f..5d9cf4cc53f 100644
--- a/core/dbt/helper_types.py
+++ b/core/dbt/helper_types.py
@@ -1,6 +1,5 @@
 # never name this package "types", or mypy will crash in ugly ways
 from datetime import timedelta
-from numbers import Real
 from typing import NewType, Dict

 from hologram import (
@@ -38,12 +37,6 @@ def json_schema(self) -> JsonDict:
         return {'type': 'number'}


-class RealEncoder(FieldEncoder):
-    @property
-    def json_schema(self):
-        return {'type': 'number'}
-
-
 class NoValue:
     """Sometimes, you want a way to say none that isn't None"""
     def __eq__(self, other):
@@ -80,6 +73,5 @@ def json_schema(self):
 JsonSchemaMixin.register_field_encoders({
     Port: PortEncoder(),
     timedelta: TimeDeltaFieldEncoder(),
-    Real: RealEncoder(),
     NoValue: NoValueEncoder(),
 })

diff --git a/core/dbt/linker.py b/core/dbt/linker.py
index e0e2d5d2776..991d704901d 100644
--- a/core/dbt/linker.py
+++ b/core/dbt/linker.py
@@ -36,7 +36,7 @@ def __init__(self, graph, manifest):
         self.graph = graph
         self.manifest = manifest
         # store the queue as a priority queue.
-        self.inner = PriorityQueue()
+        self.inner: PriorityQueue = PriorityQueue()
         # things that have been popped off the queue but not finished
         # and worker thread reservations
         self.in_progress = set()

diff --git a/core/dbt/logger.py b/core/dbt/logger.py
index c4a7b666f9c..086c1d5b88f 100644
--- a/core/dbt/logger.py
+++ b/core/dbt/logger.py
@@ -116,6 +116,11 @@ def format_text(self):
         self.formatter_class = logbook.StringFormatter
         self.format_string = self._text_format_string

+    def reset(self):
+        raise NotImplementedError(
+            'reset() not implemented in FormatterMixin subclass'
+        )
+

 class OutputHandler(logbook.StreamHandler, FormatterMixin):
     """Output handler.
@@ -372,7 +377,7 @@ def __init__(
         self._msg_buffer: Optional[List[logbook.LogRecord]] = []
         # if we get 1k messages without a logfile being set, something is wrong
         self._bufmax = 1000
-        self._log_path = None
+        self._log_path: Optional[str] = None
         # we need the base handler class' __init__ to run so handling works
         logbook.Handler.__init__(self, level, filter, bubble)
         if log_dir is not None:
@@ -426,6 +431,8 @@ def _super_init(self, log_path):
         FormatterMixin.__init__(self, DEBUG_LOG_FORMAT)

     def _replay_buffered(self):
+        assert self._msg_buffer is not None, \
+            '_msg_buffer should never be None in _replay_buffered'
         for record in self._msg_buffer:
             super().emit(record)
         self._msg_buffer = None

diff --git a/core/dbt/main.py b/core/dbt/main.py
index f578b959b8c..ac85621951c 100644
--- a/core/dbt/main.py
+++ b/core/dbt/main.py
@@ -1,3 +1,4 @@
+from typing import List
 from dbt.logger import GLOBAL_LOGGER as logger, log_cache_events, log_manager

 import argparse
@@ -202,7 +203,8 @@ def run_from_args(parsed):
     log_path = getattr(task.config, 'log_path', None)
     # we can finally set the file logger up
     log_manager.set_path(log_path)
-    logger.debug("Tracking: {}".format(dbt.tracking.active_user.state()))
+    if dbt.tracking.active_user is not None:  # mypy appeasement, always true
+        logger.debug("Tracking: {}".format(dbt.tracking.active_user.state()))

     results = None
@@ -643,7 +645,9 @@ def _build_list_subparser(subparsers, base_subparser):
         aliases=['ls'],
     )
     sub.set_defaults(cls=ListTask, which='list', rpc_method=None)
-    resource_values = list(ListTask.ALL_RESOURCE_VALUES) + ['default', 'all']
+    resource_values: List[str] = [
+        str(s) for s in ListTask.ALL_RESOURCE_VALUES
+    ] + ['default', 'all']
     sub.add_argument('--resource-type',
                      choices=resource_values,
                      action='append',

diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py
index 23951f68426..91aadd912bb 100644
--- a/core/dbt/node_runners.py
+++ b/core/dbt/node_runners.py
@@ -1,7 +1,8 @@
+import abc
 import threading
 import time
 import traceback
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional

 from dbt import deprecations
 from dbt.adapters.base import BaseRelation
@@ -11,6 +12,7 @@
     InternalException, missing_materialization
 )
 from dbt.node_types import NodeType
+from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.results import (
     RunModelResult, collect_timing_info, SourceFreshnessResult, PartialResult,
 )
@@ -30,6 +32,8 @@


 def track_model_run(index, num_nodes, run_model_result):
+    if dbt.tracking.active_user is None:
+        raise InternalException('cannot track model run with no active user')
     invocation_id = dbt.tracking.active_user.invocation_id
     dbt.tracking.track_model_run({
         "invocation_id": invocation_id,
@@ -59,7 +63,7 @@ def __init__(self, node):
         self.node = node


-class BaseRunner:
+class BaseRunner(metaclass=abc.ABCMeta):
     def __init__(self, config, adapter, node, node_index, num_nodes):
         self.config = config
         self.adapter = adapter
@@ -68,7 +72,11 @@ def __init__(self, config, adapter, node, node_index, num_nodes):
         self.num_nodes = num_nodes

         self.skip = False
-        self.skip_cause = None
+        self.skip_cause: Optional[RunModelResult] = None
+
+    @abc.abstractmethod
+    def compile(self, manifest: Manifest) -> Any:
+        pass

     def get_result_status(self, result) -> Dict[str, str]:
         if result.error:
@@ -223,7 +231,10 @@ def safe_run(self, manifest):

         # if releasing failed and the result doesn't have an error yet, set
        # an error
-        if exc_str is not None and result.error is None:
+        if (
+            exc_str is not None and result is not None and
+            result.error is None and error is None
+        ):
             error = exc_str

         if error is not None:
@@ -284,6 +295,11 @@ def on_skip(self):
                 self.num_nodes,
                 self.skip_cause
             )
+            if self.skip_cause is None:  # mypy appeasement
+                raise InternalException(
+                    'Skip cause not set but skip was somehow caused by '
+                    'an ephemeral failure'
+                )
             # set an error so dbt will exit with an error code
             error = (
                 'Compilation Error in {}, caused by compilation error '

diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py
index 9cfaeede87d..f7edef92d55 100644
--- a/core/dbt/parser/manifest.py
+++ b/core/dbt/parser/manifest.py
@@ -346,8 +346,8 @@ def load_internal(cls, root_config: RuntimeConfig) -> Manifest:


 def _check_resource_uniqueness(manifest):
-    names_resources = {}
-    alias_resources = {}
+    names_resources: Dict[str, CompileResultNode] = {}
+    alias_resources: Dict[str, CompileResultNode] = {}

     for resource, node in manifest.nodes.items():
         if node.resource_type not in NodeType.refable():

diff --git a/core/dbt/rpc/method.py b/core/dbt/rpc/method.py
index fb1c6d83d59..b2bd1b36456 100644
--- a/core/dbt/rpc/method.py
+++ b/core/dbt/rpc/method.py
@@ -103,6 +103,11 @@ def __init__(self, task_manager):
     def set_args(self, params: Parameters):
         self.params = params

+    def run(self):
+        raise InternalException(
+            'the run() method on builtins should never be called'
+        )
+
     def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin:
         try:
             params = self.get_parameters().from_dict(kwargs)

diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py
index e7961db88b0..1a95097b0ab 100644
--- a/core/dbt/rpc/task_handler.py
+++ b/core/dbt/rpc/task_handler.py
@@ -19,6 +19,7 @@
 from dbt.contracts.rpc import (
     RPCParameters, RemoteResult, TaskHandlerState, RemoteMethodFlags, TaskTags,
 )
+from dbt.exceptions import InternalException
 from dbt.logger import (
     GLOBAL_LOGGER as logger, list_handler, LogMessage, OutputHandler,
 )
@@ -112,7 +113,7 @@ def task_exec(self) -> None:
         elif result is not None:
             handler.emit_result(result)
         else:
-            error = dbt_error(dbt.exceptions.InternalException(
+            error = dbt_error(InternalException(
                 'after request handling, neither result nor error is None!'
)) handler.emit_error(error.error) @@ -202,11 +203,12 @@ def handle_completed(self): if self.handler.result is None: # there wasn't an error before, but there sure is one now self.handler.error = dbt_error( - dbt.exceptions.InternalException( + InternalException( 'got an invalid result=None, but state was {}' .format(self.handler.state) ) ) + # TODO: need to tighten RequestTaskHandler.Task to also elif self.handler.task.interpret_results(self.handler.result): self.handler.state = TaskHandlerState.Success else: @@ -232,13 +234,13 @@ def handle_error(self, exc_type, exc_value, exc_tb) -> bool: def task_teardown(self): self.handler.task.cleanup(self.handler.result) - def __exit__(self, exc_type, exc_value, exc_tb) -> bool: + def __exit__(self, exc_type, exc_value, exc_tb) -> None: try: if exc_type is not None: self.handle_error(exc_type, exc_value, exc_tb) else: self.handle_completed() - return False + return finally: # we really really promise to run your teardown self.task_teardown() @@ -302,7 +304,7 @@ def request_id(self) -> Union[str, int]: @property def method(self) -> str: if self.task.METHOD_NAME is None: # mypy appeasement - raise dbt.exceptions.InternalException( + raise InternalException( f'In the request handler, got a task({self.task}) with no ' 'METHOD_NAME' ) @@ -338,7 +340,7 @@ def _wait_for_results(self) -> RemoteResult: self.started is None or self.process is None ): - raise dbt.exceptions.InternalException( + raise InternalException( '_wait_for_results() called before handle()' ) @@ -366,7 +368,7 @@ def _wait_for_results(self) -> RemoteResult: def get_result(self) -> RemoteResult: if self.process is None: - raise dbt.exceptions.InternalException( + raise InternalException( 'get_result() called before handle()' ) @@ -411,7 +413,7 @@ def handle_singlethreaded( # note this shouldn't call self.run() as that has different semantics # (we want errors to raise) if self.process is None: # mypy appeasement - raise dbt.exceptions.InternalException( + raise InternalException( 'Cannot run a None process' ) self.process.task_exec() @@ -430,6 +432,8 @@ def start(self): # calling close(), the connection in the parent ends up throwing # 'connection already closed' exceptions cleanup_connections() + if self.process is None: + raise InternalException('self.process is None in start()!') self.process.start() self.state = TaskHandlerState.Running super().start() @@ -438,6 +442,11 @@ def _collect_parameters(self): # both get_parameters and the argparse can raise a TypeError. cls: Type[RPCParameters] = self.task.get_parameters() + if self.task_kwargs is None: + raise TypeError( + 'task_kwargs were None - unable to collect parameters' + ) + try: return cls.from_dict(self.task_kwargs) except ValidationError as exc: @@ -463,7 +472,7 @@ def handle(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: # tasks use this to set their `real_task`. self.task.set_config(self.manager.config) if self.task_params is None: # mypy appeasement - raise dbt.exceptions.InternalException( + raise InternalException( 'Task params set to None!' ) diff --git a/core/dbt/rpc/task_manager.py b/core/dbt/rpc/task_manager.py index 7fe7d8eac44..2becd7cb838 100644 --- a/core/dbt/rpc/task_manager.py +++ b/core/dbt/rpc/task_manager.py @@ -124,12 +124,6 @@ def reload_manifest(self) -> bool: self._reload_task_manager_thread(reloader) return True - def reload_builtin_tasks(self): - # reload all the non-manifest tasks because the config changed.
- # manifest tasks are still blocked so we can ignore them - for task_cls in self._task_types.builtin(): - self.add_builtin_task_handler(task_cls) - def reload_config(self): config = self.config.from_args(self.args) self.config = config diff --git a/core/dbt/semver.py b/core/dbt/semver.py index 74439f8116d..2948975c651 100644 --- a/core/dbt/semver.py +++ b/core/dbt/semver.py @@ -64,115 +64,6 @@ class VersionSpecification(JsonSchemaMixin): _VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE) -class VersionRange(dbt.utils.AttrDict): - - def _try_combine_exact(self, a, b): - if a.compare(b) == 0: - return a - else: - raise VersionsNotCompatibleException() - - def _try_combine_lower_bound_with_exact(self, lower, exact): - comparison = lower.compare(exact) - - if (comparison < 0 or - (comparison == 0 and - lower.matcher == Matchers.GREATER_THAN_OR_EQUAL)): - return exact - - raise VersionsNotCompatibleException() - - def _try_combine_lower_bound(self, a, b): - if b.is_unbounded: - return a - elif a.is_unbounded: - return b - - if not (a.is_exact or b.is_exact): - comparison = (a.compare(b) < 0) - - if comparison: - return b - else: - return a - - elif a.is_exact: - return self._try_combine_lower_bound_with_exact(b, a) - - elif b.is_exact: - return self._try_combine_lower_bound_with_exact(a, b) - - def _try_combine_upper_bound_with_exact(self, upper, exact): - comparison = upper.compare(exact) - - if (comparison > 0 or - (comparison == 0 and - upper.matcher == Matchers.LESS_THAN_OR_EQUAL)): - return exact - - raise VersionsNotCompatibleException() - - def _try_combine_upper_bound(self, a, b): - if b.is_unbounded: - return a - elif a.is_unbounded: - return b - - if not (a.is_exact or b.is_exact): - comparison = (a.compare(b) > 0) - - if comparison: - return b - else: - return a - - elif a.is_exact: - return self._try_combine_upper_bound_with_exact(b, a) - - elif b.is_exact: - return self._try_combine_upper_bound_with_exact(a, b) - - def reduce(self, other): - start = None - - if(self.start.is_exact and other.start.is_exact): - start = end = self._try_combine_exact(self.start, other.start) - - else: - start = self._try_combine_lower_bound(self.start, other.start) - end = self._try_combine_upper_bound(self.end, other.end) - - if start.compare(end) > 0: - raise VersionsNotCompatibleException() - - return VersionRange(start=start, end=end) - - def __str__(self): - result = [] - - if self.start.is_unbounded and self.end.is_unbounded: - return 'ANY' - - if not self.start.is_unbounded: - result.append(self.start.to_version_string()) - - if not self.end.is_unbounded: - result.append(self.end.to_version_string()) - - return ', '.join(result) - - def to_version_string_pair(self): - to_return = [] - - if not self.start.is_unbounded: - to_return.append(self.start.to_version_string()) - - if not self.end.is_unbounded: - to_return.append(self.end.to_version_string()) - - return to_return - - @dataclass class VersionSpecifier(VersionSpecification): def to_version_string(self, skip_matcher=False): @@ -212,8 +103,8 @@ def __str__(self): return self.to_version_string() def to_range(self): - range_start = UnboundedVersionSpecifier() - range_end = UnboundedVersionSpecifier() + range_start: VersionSpecifier = UnboundedVersionSpecifier() + range_end: VersionSpecifier = UnboundedVersionSpecifier() if self.matcher == Matchers.EXACT: range_start = self @@ -299,6 +190,118 @@ def is_exact(self): return self.matcher == Matchers.EXACT +@dataclass +class VersionRange: + start: VersionSpecifier + end: 
VersionSpecifier + + def _try_combine_exact(self, a, b): + if a.compare(b) == 0: + return a + else: + raise VersionsNotCompatibleException() + + def _try_combine_lower_bound_with_exact(self, lower, exact): + comparison = lower.compare(exact) + + if (comparison < 0 or + (comparison == 0 and + lower.matcher == Matchers.GREATER_THAN_OR_EQUAL)): + return exact + + raise VersionsNotCompatibleException() + + def _try_combine_lower_bound(self, a, b): + if b.is_unbounded: + return a + elif a.is_unbounded: + return b + + if not (a.is_exact or b.is_exact): + comparison = (a.compare(b) < 0) + + if comparison: + return b + else: + return a + + elif a.is_exact: + return self._try_combine_lower_bound_with_exact(b, a) + + elif b.is_exact: + return self._try_combine_lower_bound_with_exact(a, b) + + def _try_combine_upper_bound_with_exact(self, upper, exact): + comparison = upper.compare(exact) + + if (comparison > 0 or + (comparison == 0 and + upper.matcher == Matchers.LESS_THAN_OR_EQUAL)): + return exact + + raise VersionsNotCompatibleException() + + def _try_combine_upper_bound(self, a, b): + if b.is_unbounded: + return a + elif a.is_unbounded: + return b + + if not (a.is_exact or b.is_exact): + comparison = (a.compare(b) > 0) + + if comparison: + return b + else: + return a + + elif a.is_exact: + return self._try_combine_upper_bound_with_exact(b, a) + + elif b.is_exact: + return self._try_combine_upper_bound_with_exact(a, b) + + def reduce(self, other): + start = None + + if (self.start.is_exact and other.start.is_exact): + start = end = self._try_combine_exact(self.start, other.start) + + else: + start = self._try_combine_lower_bound(self.start, other.start) + end = self._try_combine_upper_bound(self.end, other.end) + + if start.compare(end) > 0: + raise VersionsNotCompatibleException() + + return VersionRange(start=start, end=end) + + def __str__(self): + result = [] + + if self.start.is_unbounded and self.end.is_unbounded: + return 'ANY' + + if not self.start.is_unbounded: + result.append(self.start.to_version_string()) + + if not self.end.is_unbounded: + result.append(self.end.to_version_string()) + + return ', '.join(result) + + def to_version_string_pair(self): + to_return = [] + + if not self.start.is_unbounded: + to_return.append(self.start.to_version_string()) + + if not self.end.is_unbounded: + to_return.append(self.end.to_version_string()) + + return to_return + + class UnboundedVersionSpecifier(VersionSpecifier): def __init__(self, *args, **kwargs): super().__init__( diff --git a/core/dbt/source_config.py b/core/dbt/source_config.py index 1aea8c6f591..828abe284d6 100644 --- a/core/dbt/source_config.py +++ b/core/dbt/source_config.py @@ -1,3 +1,5 @@ +from typing import Dict, Any + import dbt.exceptions from dbt.utils import deep_merge @@ -45,10 +47,10 @@ def __init__(self, active_project, own_project, fqn, node_type): self.AdapterSpecificConfigs = adapter_class.AdapterSpecificConfigs # the config options defined within the model - self.in_model_config = {} + self.in_model_config: Dict[str, Any] = {} def _merge(self, *configs): - merged_config = {} + merged_config: Dict[str, Any] = {} for config in configs: # Do not attempt to deep merge clobber fields config = config.copy() @@ -143,7 +145,7 @@ def __get_as_list(relevant_configs, key): def smart_update(self, mutable_config, new_configs): config_keys = self.ConfigKeys | self.AdapterSpecificConfigs - relevant_configs = { + relevant_configs: Dict[str, Any] = { key: new_configs[key] for key in new_configs if key in config_keys } @@ -172,7 +174,7 
@@ def smart_update(self, mutable_config, new_configs): def get_project_config(self, runtime_config): # most configs are overwritten by a more specific config, but pre/post # hooks are appended! - config = {} + config: Dict[str, Any] = {} for k in self.AppendListFields: config[k] = [] for k in self.ExtendDictFields: diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index d40bf9f1533..d8851a740d8 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -2,6 +2,7 @@ import os import platform import sys +from typing import Optional, Dict, Any, List from dbt.logger import GLOBAL_LOGGER as logger import dbt.clients.system @@ -76,13 +77,13 @@ def __init__(self, args, config): ) # set by _load_* - self.profile = None + self.profile: Optional[Profile] = None self.profile_fail_details = '' - self.raw_profile_data = None - self.profile_name = None - self.project = None + self.raw_profile_data: Optional[Dict[str, Any]] = None + self.profile_name: Optional[str] = None + self.project: Optional[Project] = None self.project_fail_details = '' - self.messages = [] + self.messages: List[str] = [] @property def project_profile(self): @@ -137,6 +138,7 @@ def _load_project(self): def _profile_found(self): if not self.raw_profile_data: return red('ERROR not found') + assert self.raw_profile_data is not None if self.profile_name in self.raw_profile_data: return green('OK found') else: @@ -147,6 +149,10 @@ def _target_found(self): self.target_name) if not requirements: return red('ERROR not found') + # mypy appeasement, we checked just above + assert self.raw_profile_data is not None + assert self.profile_name is not None + assert self.target_name is not None if self.profile_name not in self.raw_profile_data: return red('ERROR not found') profiles = self.raw_profile_data[self.profile_name]['outputs'] @@ -186,6 +192,10 @@ def _choose_profile_name(self): def _choose_target_name(self): has_raw_profile = (self.raw_profile_data and self.profile_name and self.profile_name in self.raw_profile_data) + if has_raw_profile: + # mypy appeasement, we checked just above + assert self.raw_profile_data is not None + assert self.profile_name is not None raw_profile = self.raw_profile_data[self.profile_name] diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index f8ee2263a5d..97c03a1589a 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -12,6 +12,7 @@ TableMetadata, CatalogTable, CatalogResults, Primitive, CatalogKey, StatsItem, StatsDict, ColumnMetadata ) +from dbt.exceptions import InternalException from dbt.include.global_project import DOCS_INDEX_FILE_PATH import dbt.ui.printer import dbt.utils @@ -179,6 +180,10 @@ def _coerce_decimal(value): class GenerateTask(CompileTask): def _get_manifest(self) -> Manifest: + if self.manifest is None: + raise InternalException( + 'manifest should not be None in _get_manifest' + ) return self.manifest def run(self): @@ -195,6 +200,11 @@ def run(self): DOCS_INDEX_FILE_PATH, os.path.join(self.config.target_path, 'index.html')) + if self.manifest is None: + raise InternalException( + 'self.manifest was None in run!' 
+ ) + adapter = get_adapter(self.config) with adapter.connection_named('generate_catalog'): dbt.ui.printer.print_timestamped_line("Building catalog") diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 3113d46ecca..d6c5ce33896 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -2,7 +2,7 @@ from dbt.task.runnable import GraphRunnableTask, ManifestTask from dbt.node_types import NodeType -import dbt.exceptions +from dbt.exceptions import RuntimeException, InternalException from dbt.logger import log_manager, GLOBAL_LOGGER as logger @@ -33,11 +33,11 @@ def __init__(self, args, config): self.args.single_threaded = True if self.args.models: if self.args.select: - raise dbt.exceptions.RuntimeException( + raise RuntimeException( '"models" and "select" are mutually exclusive arguments' ) if self.args.resource_types: - raise dbt.exceptions.RuntimeException( + raise RuntimeException( '"models" and "resource_type" are mutually exclusive ' 'arguments' ) @@ -53,6 +53,10 @@ def _iterate_selected_nodes(self): if not nodes: logger.warning('No nodes selected!') return + if self.manifest is None: + raise InternalException( + 'manifest is None in _iterate_selected_nodes' + ) for node in nodes: yield self.manifest.nodes[node] @@ -95,7 +99,7 @@ def run(self): elif output == 'path': generator = self.generate_paths else: - raise dbt.exceptions.InternalException( + raise InternalException( 'Invalid output {}'.format(output) ) for result in generator(): diff --git a/core/dbt/task/rpc/base.py b/core/dbt/task/rpc/base.py index 7f8e425d264..6989756b884 100644 --- a/core/dbt/task/rpc/base.py +++ b/core/dbt/task/rpc/base.py @@ -9,7 +9,9 @@ class RPCTask( ): def __init__(self, args, config, manifest): super().__init__(args, config) - RemoteManifestMethod.__init__(self, args, config, manifest) + RemoteManifestMethod.__init__( + self, args, config, manifest # type: ignore + ) def load_manifest(self): # we started out with a manifest! diff --git a/core/dbt/task/rpc/cli.py b/core/dbt/task/rpc/cli.py index d7d176be4bd..e3bb477c454 100644 --- a/core/dbt/task/rpc/cli.py +++ b/core/dbt/task/rpc/cli.py @@ -36,8 +36,11 @@ def __init__(self, args, config, manifest): def set_config(self, config): super().set_config(config) + if self.task_type is None: + raise InternalException('task type not set for set_config') if issubclass(self.task_type, RemoteManifestMethod): - self.real_task = self.task_type( + task_type: Type[RemoteManifestMethod] = self.task_type + self.real_task = task_type( self.args, self.config, self.manifest ) else: @@ -53,7 +56,10 @@ def set_args(self, params: RPCCliParameters) -> None: self.task_type = self.get_rpc_task_cls() def get_flags(self): - return self.task_type.get_flags(self) + if self.task_type is None: + raise InternalException('task type not set for get_flags') + # this is a kind of egregious hack from a type perspective... 
+ return self.task_type.get_flags(self) # type: ignore def get_rpc_task_cls(self) -> Type[HasCLI]: # This is obnoxious, but we don't have actual access to the TaskManager diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py index faf91902c43..0449f306847 100644 --- a/core/dbt/task/rpc/project_commands.py +++ b/core/dbt/task/rpc/project_commands.py @@ -15,8 +15,10 @@ RPCSourceFreshnessParameters, ) from dbt.rpc.method import ( - Parameters, + Parameters, RemoteManifestMethod ) + +from dbt.task.base import BaseTask from dbt.task.compile import CompileTask from dbt.task.freshness import FreshnessTask from dbt.task.generate import GenerateTask @@ -33,6 +35,7 @@ class RPCCommandTask( RPCTask[Parameters], HasCLI[Parameters, RemoteExecutionResult], + BaseTask, ): @staticmethod def _listify( @@ -118,12 +121,22 @@ def get_catalog_results( class RemoteRunOperationTask( - RPCTask[RPCRunOperationParameters], - HasCLI[RPCRunOperationParameters, RemoteRunOperationResult], RunOperationTask, + RemoteManifestMethod[RPCRunOperationParameters, RemoteRunOperationResult], + HasCLI[RPCRunOperationParameters, RemoteRunOperationResult], ): METHOD_NAME = 'run-operation' + def __init__(self, args, config, manifest): + super().__init__(args, config) + RemoteManifestMethod.__init__( + self, args, config, manifest # type: ignore + ) + + def load_manifest(self): + # we started out with a manifest! + pass + def set_args(self, params: RPCRunOperationParameters) -> None: self.args.macro = params.macro self.args.args = params.args @@ -138,8 +151,14 @@ def _runtime_initialize(self): return RunOperationTask._runtime_initialize(self) def handle_request(self) -> RemoteRunOperationResult: - success, _ = RunOperationTask.run(self) - result = RemoteRunOperationResult(logs=[], success=success) + base = RunOperationTask.run(self) + result = RemoteRunOperationResult( + results=base.results, + generated_at=base.generated_at, + logs=[], + success=base.success, + elapsed_time=base.elapsed_time + ) return result def interpret_results(self, results): diff --git a/core/dbt/task/rpc/server.py b/core/dbt/task/rpc/server.py index 2e8ecd15d31..179a8e1ea50 100644 --- a/core/dbt/task/rpc/server.py +++ b/core/dbt/task/rpc/server.py @@ -6,7 +6,7 @@ import os import signal from contextlib import contextmanager -from typing import Iterator, Optional +from typing import Iterator, Optional, List, Type from werkzeug.middleware.dispatcher import DispatcherMiddleware from werkzeug.wrappers import Request, Response @@ -19,7 +19,7 @@ log_manager, ) from dbt.rpc.logger import ServerContext, HTTPRequest, RPCResponse -from dbt.rpc.method import TaskTypes +from dbt.rpc.method import TaskTypes, RemoteMethod from dbt.rpc.response_manager import ResponseManager from dbt.rpc.task_manager import TaskManager from dbt.task.base import ConfiguredTask @@ -73,7 +73,9 @@ def signhup_replace() -> Iterator[bool]: class RPCServerTask(ConfiguredTask): DEFAULT_LOG_FORMAT = 'json' - def __init__(self, args, config, tasks: Optional[TaskTypes] = None): + def __init__( + self, args, config, tasks: Optional[List[Type[RemoteMethod]]] = None + ) -> None: if os.name == 'nt': raise RuntimeException( 'The dbt RPC server is not supported on windows' diff --git a/core/dbt/task/rpc/sql_commands.py b/core/dbt/task/rpc/sql_commands.py index 4b072a4eb4d..c9e4c09b971 100644 --- a/core/dbt/task/rpc/sql_commands.py +++ b/core/dbt/task/rpc/sql_commands.py @@ -8,7 +8,7 @@ from dbt.compilation import compile_manifest, compile_node from dbt.contracts.rpc 
import RPCExecParameters from dbt.contracts.rpc import RemoteExecutionResult -from dbt.exceptions import RPCKilledException +from dbt.exceptions import RPCKilledException, InternalException from dbt.logger import GLOBAL_LOGGER as logger from dbt.parser.results import ParseResult from dbt.parser.rpc import RPCCallParser, RPCMacroParser @@ -75,6 +75,10 @@ def _compile_ancestors(self, unique_id: str): # this just gets a transitive closure of the nodes. We could build a # special GraphQueue around this, but we do them all in the main thread # so we only care about preserving dependency order anyway + if self.linker is None or self.manifest is None: + raise InternalException( + 'linker and manifest not set in _compile_ancestors' + ) sorted_ancestors = self.linker.sorted_ephemeral_ancestors( self.manifest, unique_id, @@ -90,6 +94,11 @@ def _compile_ancestors(self, unique_id: str): ) def _get_exec_node(self): + if self.manifest is None: + raise InternalException( + 'manifest not set in _get_exec_node' + ) + results = ParseResult.rpc() macro_overrides = {} macros = self.args.macros @@ -107,18 +116,18 @@ def _get_exec_node(self): root_project=self.config, macro_manifest=self.manifest, ) - node = rpc_parser.parse_remote(sql, self.args.name) + rpc_node = rpc_parser.parse_remote(sql, self.args.name) self.manifest = ParserUtils.add_new_refs( manifest=self.manifest, config=self.config, - node=node, + node=rpc_node, macros=macro_overrides ) # don't write our new, weird manifest! self.linker = compile_manifest(self.config, self.manifest, write=False) - self._compile_ancestors(node.unique_id) - return node + self._compile_ancestors(rpc_node.unique_id) + return rpc_node def _raise_set_error(self): if self._raise_next_tick is not None: diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index bed80c0397c..1cd03913a2c 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -1,6 +1,6 @@ import functools import time -from typing import List, Dict, Any, Set, Tuple, Optional +from typing import List, Dict, Any, Iterable, Set, Tuple, Optional from dbt.logger import ( GLOBAL_LOGGER as logger, @@ -10,6 +10,7 @@ TimestampNamed, DbtModelState, ) +from dbt.exceptions import InternalException from dbt.node_types import NodeType, RunHookType from dbt.node_runners import ModelRunner @@ -24,9 +25,9 @@ get_counts from dbt.compilation import compile_node +from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.parsed import ParsedHookNode from dbt.task.compile import CompileTask -from dbt.utils import get_nodes_by_tags class Timer: @@ -61,6 +62,20 @@ def _hook_list() -> List[ParsedHookNode]: return [] +def get_hooks_by_tags( + nodes: Iterable[CompileResultNode], + match_tags: Set[str], +) -> List[ParsedHookNode]: + matched_nodes = [] + for node in nodes: + if not isinstance(node, ParsedHookNode): + continue + node_tags = node.tags + if len(set(node_tags) & match_tags): + matched_nodes.append(node) + return matched_nodes + + class RunTask(CompileTask): def __init__(self, args, config): super().__init__(args, config) @@ -84,7 +99,7 @@ def get_hook_sql(self, adapter, hook, idx, num_hooks, extra_context): hook_obj = get_hook(statement, index=hook_index) return hook_obj.sql or '' - def _hook_keyfunc(self, hook: ParsedHookNode): + def _hook_keyfunc(self, hook: ParsedHookNode) -> Tuple[str, Optional[int]]: package_name = hook.package_name if package_name == self.config.project_name: package_name = BiggestName('') @@ -93,9 +108,15 @@ def _hook_keyfunc(self, hook: ParsedHookNode): def 
get_hooks_by_type( self, hook_type: RunHookType ) -> List[ParsedHookNode]: + + if self.manifest is None: + raise InternalException( + 'self.manifest was None in get_hooks_by_type' + ) + nodes = self.manifest.nodes.values() # find all hooks defined in the manifest (could be multiple projects) - hooks = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation) + hooks: List[ParsedHookNode] = get_hooks_by_tags(nodes, {hook_type}) hooks.sort(key=self._hook_keyfunc) return hooks diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index ed280acbe84..6b318fa005c 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -1,6 +1,13 @@ +from datetime import datetime +from typing import Dict, Any + +import agate + from dbt.logger import GLOBAL_LOGGER as logger from dbt.task.runnable import ManifestTask from dbt.adapters.factory import get_adapter +from dbt.contracts.results import RunOperationResult +from dbt.exceptions import InternalException import dbt import dbt.utils @@ -17,14 +24,16 @@ def _get_macro_parts(self): return package_name, macro_name - def _get_kwargs(self): + def _get_kwargs(self) -> Dict[str, Any]: return dbt.utils.parse_cli_vars(self.args.args) - def compile_manifest(self): + def compile_manifest(self) -> None: # skip building a linker, but do make sure to build the flat graph + if self.manifest is None: + raise InternalException('manifest was None in compile_manifest') self.manifest.build_flat_graph() - def _run_unsafe(self): + def _run_unsafe(self) -> agate.Table: adapter = get_adapter(self.config) package_name, macro_name = self._get_macro_parts() @@ -41,27 +50,34 @@ def _run_unsafe(self): return res - def run(self): + def run(self) -> RunOperationResult: + start = datetime.utcnow() self._runtime_initialize() try: - result = self._run_unsafe() + self._run_unsafe() except dbt.exceptions.Exception as exc: logger.error( 'Encountered an error while running operation: {}' .format(exc) ) logger.debug('', exc_info=True) - return False, None + success = False except Exception as exc: logger.error( 'Encountered an uncaught exception while running operation: {}' .format(exc) ) logger.debug('', exc_info=True) - return False, None + success = False else: - return True, result + success = True + end = datetime.utcnow() + return RunOperationResult( + results=[], + generated_at=end, + elapsed_time=(end - start).total_seconds(), + success=success + ) def interpret_results(self, results): - success, _ = results - return success + return results.success diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 394d8f09ada..ad508649eb3 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -2,6 +2,7 @@ import time from datetime import datetime from multiprocessing.dummy import Pool as ThreadPool +from typing import Optional, Dict, List, Set, Tuple from dbt.task.base import ConfiguredTask from dbt.adapters.factory import get_adapter @@ -16,7 +17,14 @@ NodeCount, ) from dbt.compilation import compile_manifest + +from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.graph.manifest import Manifest from dbt.contracts.results import ExecutionResult +from dbt.exceptions import ( + InternalException, NotImplementedException, RuntimeException +) +from dbt.linker import Linker, GraphQueue from dbt.perf_utils import get_full_manifest import dbt.exceptions @@ -39,14 +47,18 @@ def write_manifest(config, manifest): class ManifestTask(ConfiguredTask): def __init__(self, args, config): 
super().__init__(args, config) - self.manifest = None - self.linker = None + self.manifest: Optional[Manifest] = None + self.linker: Optional[Linker] = None def load_manifest(self): self.manifest = get_full_manifest(self.config) write_manifest(self.config, self.manifest) def compile_manifest(self): + if self.manifest is None: + raise InternalException( + 'compile_manifest called before manifest was loaded' + ) self.linker = compile_manifest(self.config, self.manifest) self.manifest.build_flat_graph() @@ -58,11 +70,11 @@ def _runtime_initialize(self): class GraphRunnableTask(ManifestTask): def __init__(self, args, config): super().__init__(args, config) - self.job_queue = None - self._flattened_nodes = None + self.job_queue: Optional[GraphQueue] = None + self._flattened_nodes: Optional[List[CompileResultNode]] = None - self.run_count = 0 - self.num_nodes = None + self.run_count: int = 0 + self.num_nodes: int = 0 self.node_results = [] self._skipped_children = {} self._raise_next_tick = None @@ -71,6 +83,11 @@ def index_offset(self, value: int) -> int: return value def select_nodes(self): + if self.manifest is None or self.linker is None: + raise InternalException( + 'select_nodes called before manifest and linker were loaded' + ) + selector = dbt.graph.selector.NodeSelector( self.linker.graph, self.manifest ) @@ -79,6 +96,10 @@ def select_nodes(self): def _runtime_initialize(self): super()._runtime_initialize() + if self.manifest is None or self.linker is None: + raise InternalException( + '_runtime_initialize never loaded the manifest and linker!' + ) selected_nodes = self.select_nodes() self.job_queue = self.linker.as_graph_queue(self.manifest, selected_nodes) @@ -97,10 +118,10 @@ def raise_on_first_error(self): return False def build_query(self): - raise dbt.exceptions.NotImplementedException('Not Implemented') + raise NotImplementedException('Not Implemented') def get_runner_type(self): - raise dbt.exceptions.NotImplementedException('Not Implemented') + raise NotImplementedException('Not Implemented') def result_path(self): return os.path.join(self.config.target_path, RESULT_FILE_NAME) @@ -128,7 +149,7 @@ def call_runner(self, runner): with startctx, extended_metadata: logger.debug('Began running node {}'.format( runner.node.unique_id)) - status = 'error' # we must have an error if we don't see this + status: Dict[str, str] = {'state': 'error'} try: result = runner.run_with_hooks(self.manifest) status = runner.get_result_status(result) @@ -161,16 +182,26 @@ def _submit(self, pool, args, callback): def _raise_set_error(self): if self._raise_next_tick is not None: - raise dbt.exceptions.RuntimeException(self._raise_next_tick) + raise RuntimeException(self._raise_next_tick) def run_queue(self, pool): """Given a pool, submit jobs from the queue to the pool. """ + if self.job_queue is None: + raise InternalException( + 'Got to run_queue with no job queue set' + ) + def callback(result): """Note: mark_done, at a minimum, must happen here or dbt will deadlock during ephemeral result error handling! 
""" self._handle_result(result) + + if self.job_queue is None: + raise InternalException( + 'Got to run_queue callback with no job queue set' + ) self.job_queue.mark_done(result.node.unique_id) while not self.job_queue.empty(): @@ -193,7 +224,7 @@ def callback(result): return def _handle_result(self, result): - """Mark the result as completed, insert the `CompiledResultNode` into + """Mark the result as completed, insert the `CompileResultNode` into the manifest, and mark any descendants (potentially with a 'cause' if the result was an ephemeral model) as skipped. """ @@ -202,6 +233,10 @@ def _handle_result(self, result): self.node_results.append(result) node = result.node + + if self.manifest is None: + raise InternalException('manifest was None in _handle_result') + self.manifest.update_node(node) if result.error is not None: @@ -262,6 +297,8 @@ def execute_nodes(self): return self.node_results def _mark_dependent_errors(self, node_id, result, cause): + if self.linker is None: + raise InternalException('linker is None in _mark_dependent_errors') for dep_node_id in self.linker.get_dependent_nodes(node_id): self._skipped_children[dep_node_id] = cause @@ -304,6 +341,11 @@ def run(self): """ self._runtime_initialize() + if self._flattened_nodes is None: + raise InternalException( + 'after _runtime_initialize, _flattened_nodes was still None' + ) + if len(self._flattened_nodes) == 0: logger.warning("WARNING: Nothing to do. Try checking your model " "configs and model specification args") @@ -333,7 +375,10 @@ def interpret_results(self, results): return len(failures) == 0 def get_model_schemas(self, selected_uids): - schemas = set() + if self.manifest is None: + raise InternalException('manifest was None in get_model_schemas') + + schemas: Set[Tuple[str, str]] = set() for node in self.manifest.nodes.values(): if node.unique_id not in selected_uids: continue @@ -346,7 +391,7 @@ def create_schemas(self, adapter, selected_uids): required_schemas = self.get_model_schemas(selected_uids) required_databases = set(db for db, _ in required_schemas) - existing_schemas_lowered = set() + existing_schemas_lowered: Set[Tuple[str, str]] = set() for db in required_databases: existing_schemas_lowered.update( (db.lower(), s.lower()) for s in adapter.list_schemas(db)) diff --git a/core/dbt/task/serve.py b/core/dbt/task/serve.py index d92f37312cb..25666f5f572 100644 --- a/core/dbt/task/serve.py +++ b/core/dbt/task/serve.py @@ -25,10 +25,11 @@ def run(self): ) logger.info("Press Ctrl+C to exit.\n\n") - httpd = TCPServer( + # mypy doesn't think SimpleHTTPRequestHandler is ok here, but it is + httpd = TCPServer( # type: ignore ('0.0.0.0', port), - SimpleHTTPRequestHandler - ) + SimpleHTTPRequestHandler # type: ignore + ) # type: ignore try: webbrowser.open_new_tab('http://127.0.0.1:{}'.format(port)) diff --git a/core/dbt/tracking.py b/core/dbt/tracking.py index 92a3888e871..c1f45a5d82c 100644 --- a/core/dbt/tracking.py +++ b/core/dbt/tracking.py @@ -1,3 +1,5 @@ +from typing import Optional + from dbt.logger import GLOBAL_LOGGER as logger from dbt import version as dbt_version from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger @@ -55,8 +57,6 @@ def http_get(self, payload): emitter = TimeoutEmitter() tracker = Tracker(emitter, namespace="cf", app_id="dbt") -active_user = None - class User: @@ -119,6 +119,9 @@ def get_cookie(self): return user +active_user: Optional[User] = None + + def get_run_type(args): return 'regular' @@ -239,6 +242,8 @@ def track_invocation_start(config=None, args=None): 
def track_model_run(options): context = [SelfDescribingJson(RUN_MODEL_SPEC, options)] + assert active_user is not None, \ + 'Cannot track model runs when active user is None' track( active_user, @@ -251,6 +256,8 @@ def track_model_run(options): def track_rpc_request(options): context = [SelfDescribingJson(RPC_REQUEST_SPEC, options)] + assert active_user is not None, \ + 'Cannot track rpc requests when active user is None' track( active_user, @@ -263,6 +270,9 @@ def track_rpc_request(options): def track_package_install(options): context = [SelfDescribingJson(PACKAGE_INSTALL_SPEC, options)] + assert active_user is not None, \ + 'Cannot track package installs when active user is None' + track( active_user, category="dbt", @@ -282,6 +292,10 @@ def track_invocation_end( get_platform_context(), get_dbt_env_context() ] + + assert active_user is not None, \ + 'Cannot track invocation end when active user is None' + track( active_user, category="dbt", @@ -294,6 +308,8 @@ def track_invocation_end( def track_invalid_invocation( config=None, args=None, result_type=None ): + assert active_user is not None, \ + 'Cannot track invalid invocations when active user is None' user = active_user invocation_context = get_invocation_invalid_context( @@ -344,7 +360,8 @@ def __init__(self): super().__init__() def process(self, record): - record.extra.update({ - "run_started_at": active_user.run_started_at.isoformat(), - "invocation_id": active_user.invocation_id, - }) + if active_user is not None: + record.extra.update({ + "run_started_at": active_user.run_started_at.isoformat(), + "invocation_id": active_user.invocation_id, + }) diff --git a/core/dbt/utils.py b/core/dbt/utils.py index d8dd6547d72..562b678d94d 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -9,9 +9,8 @@ import os from enum import Enum from typing import ( - Tuple, Type, Any, Optional, TypeVar, Dict, Iterable, Set, List, Union + Tuple, Type, Any, Optional, TypeVar, Dict, Union, Callable ) -from typing_extensions import Protocol import dbt.exceptions @@ -238,8 +237,14 @@ def deep_merge_item(destination, key, value): destination[key] = value -def _deep_map(func, value, keypath): - atomic_types = (int, float, str, type(None), bool) +def _deep_map( + func: Callable[[Any, Tuple[Union[str, int], ...]], Any], + value: Any, + keypath: Tuple[Union[str, int], ...], +) -> Any: + atomic_types: Tuple[Type[Any], ...] = (int, float, str, type(None), bool) + + ret: Any if isinstance(value, list): ret = [ @@ -248,13 +253,14 @@ def _deep_map(func, value, keypath): ] elif isinstance(value, dict): ret = { - k: _deep_map(func, v, (keypath + (k,))) + k: _deep_map(func, v, (keypath + (str(k),))) for k, v in value.items() } elif isinstance(value, atomic_types): ret = func(value, keypath) else: - ok_types = (list, dict) + atomic_types + container_types: Tuple[Type[Any], ...] 
= (list, dict) + ok_types = container_types + atomic_types raise dbt.exceptions.DbtConfigError( 'in _deep_map, expected one of {!r}, got {!r}' .format(ok_types, type(value)) @@ -317,24 +323,6 @@ def get_pseudo_hook_path(hook_name): return os.path.join(*path_parts) -class _Tagged(Protocol): - tags: Iterable[str] - - -Tagged = TypeVar('Tagged', bound=_Tagged) - - -def get_nodes_by_tags( - nodes: Iterable[Tagged], match_tags: Set[str], resource_type: NodeType -) -> List[Tagged]: - matched_nodes = [] - for node in nodes: - node_tags = node.tags - if len(set(node_tags) & match_tags): - matched_nodes.append(node) - return matched_nodes - - def md5(string): return hashlib.md5(string.encode('utf-8')).hexdigest() @@ -421,11 +409,11 @@ def invalid_source_fail_unless_test(node, target_name, target_table_name): target_table_name) -def parse_cli_vars(var_string): +def parse_cli_vars(var_string: str) -> Dict[str, Any]: try: cli_vars = yaml_helper.load_yaml_text(var_string) var_type = type(cli_vars) - if var_type == dict: + if var_type is dict: return cli_vars else: type_name = var_type.__name__ @@ -484,7 +472,9 @@ def default(self, obj): return str(obj) -def translate_aliases(kwargs, aliases): +def translate_aliases( + kwargs: Dict[str, Any], aliases: Dict[str, str] +) -> Dict[str, Any]: """Given a dict of keyword arguments and a dict mapping aliases to their canonical values, canonicalize the keys in the kwargs dict. @@ -492,7 +482,7 @@ def translate_aliases(kwargs, aliases): canonical key. :raises: `AliasException`, if a canonical key is defined more than once. """ - result = {} + result: Dict[str, Any] = {} for given_key, value in kwargs.items(): canonical_key = aliases.get(given_key, given_key) diff --git a/dev_requirements.txt b/dev_requirements.txt index a5c31a1a5d4..89f4b147a3a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -7,4 +7,4 @@ tox==2.5.0 ipdb pytest-xdist>=1.28.0,<2 flaky>=3.5.3,<4 -mypy==0.720 +mypy==0.761 diff --git a/mypy.ini b/mypy.ini index 51fada1b1dc..f9b2602b536 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -mypy_path = ./third-party-stubs +mypy_path = ./third-party-stubs,./core namespace_packages = True diff --git a/test/integration/022_bigquery_test/macros/test_int_inference.sql b/test/integration/022_bigquery_test/macros/test_int_inference.sql index 852accbb0e7..a1ab1c8a900 100644 --- a/test/integration/022_bigquery_test/macros/test_int_inference.sql +++ b/test/integration/022_bigquery_test/macros/test_int_inference.sql @@ -1,4 +1,11 @@ +{% macro assert_eq(value, expected, msg) %} + {% if value != expected %} + {% do exceptions.raise_compiler_error(msg ~ value) %} + {% endif %} +{% endmacro %} + + {% macro test_int_inference() %} {% set sql %} @@ -8,6 +15,22 @@ 2 as int_2 {% endset %} - {% do return(run_query(sql)) %} + {% set result = run_query(sql) %} + {% do assert_eq((result | length), 1, 'expected 1 result, got ') %} + {% set actual_0 = result[0]['int_0'] %} + {% set actual_1 = result[0]['int_1'] %} + {% set actual_2 = result[0]['int_2'] %} + + {% do assert_eq(actual_0, 0, 'expected actual_0 to be 0, it was ') %} + {% do assert_eq((actual_0 | string), '0', 'expected string form of actual_0 to be 0, it was ') %} + {% do assert_eq((actual_0 * 2), 0, 'expected actual_0 * 2 to be 0, it was ') %} {# not 00 #} + + {% do assert_eq(actual_1, 1, 'expected actual_1 to be 1, it was ') %} + {% do assert_eq((actual_1 | string), '1', 'expected string form of actual_1 to be 1, it was ') %} + {% do assert_eq((actual_1 * 
2 to be 2, it was ') %} {# not 11 #} + + {% do assert_eq(actual_2, 2, 'expected actual_2 to be 2, it was ') %} + {% do assert_eq((actual_2 | string), '2', 'expected string form of actual_2 to be 2, it was ') %} + {% do assert_eq((actual_2 * 2), 4, 'expected actual_2 * 2 to be 4, it was ') %} {# not 22 #} {% endmacro %} diff --git a/test/integration/022_bigquery_test/test_bigquery_query_results.py b/test/integration/022_bigquery_test/test_bigquery_query_results.py index b7d6109f91e..55dc670f6ee 100644 --- a/test/integration/022_bigquery_test/test_bigquery_query_results.py +++ b/test/integration/022_bigquery_test/test_bigquery_query_results.py @@ -1,5 +1,5 @@ from test.integration.base import DBTIntegrationTest, use_profile -import json + class TestBaseBigQueryResults(DBTIntegrationTest): @@ -19,21 +19,5 @@ def project_config(self): @use_profile('bigquery') def test__bigquery_type_inference(self): - _, test_results = self.run_dbt(['run-operation', 'test_int_inference']) - self.assertEqual(len(test_results), 1) - - actual_0 = test_results.rows[0]['int_0'] - actual_1 = test_results.rows[0]['int_1'] - actual_2 = test_results.rows[0]['int_2'] - - self.assertEqual(actual_0, 0) - self.assertEqual(str(actual_0), '0') - self.assertEqual(actual_0 * 2, 0) # not 00 - - self.assertEqual(actual_1, 1) - self.assertEqual(str(actual_1), '1') - self.assertEqual(actual_1 * 2, 2) # not 11 - - self.assertEqual(actual_2, 2) - self.assertEqual(str(actual_2), '2') - self.assertEqual(actual_2 * 2, 4) # not 22 + result = self.run_dbt(['run-operation', 'test_int_inference']) + self.assertTrue(result.success) diff --git a/test/unit/test_semver.py b/test/unit/test_semver.py index 05cde5b450f..6162cded2b3 100644 --- a/test/unit/test_semver.py +++ b/test/unit/test_semver.py @@ -26,7 +26,7 @@ def assertVersionSetResult(self, inputs, output_range): expected = create_range(*output_range) for permutation in itertools.permutations(inputs): - self.assertDictEqual( + self.assertEqual( reduce_versions(*permutation), expected) diff --git a/third-party-stubs/agate/__init__.pyi b/third-party-stubs/agate/__init__.pyi index b1948fbb394..7f21d3badf7 100644 --- a/third-party-stubs/agate/__init__.pyi +++ b/third-party-stubs/agate/__init__.pyi @@ -1,8 +1,16 @@ from collections import Sequence -from typing import Any, Optional, Callable +from typing import Any, Optional, Callable, Iterable, Dict, Union from . import data_types as data_types +from .data_types import ( + Text as Text, + Number as Number, + Boolean as Boolean, + DateTime as DateTime, + Date as Date, + TimeDelta as TimeDelta, +) class MappedSequence(Sequence): @@ -43,6 +51,12 @@ class Table: def print_csv(self, **kwargs: Any) -> None: ... def print_json(self, **kwargs: Any) -> None: ... def where(self, test: Callable[[Row], bool]) -> 'Table': ... + def select(self, key: Union[Iterable[str], str]) -> 'Table': ... + # these definitions are much narrower than what's actually accepted + @classmethod + def from_object(cls, obj: Iterable[Dict[str, Any]], *, column_types: Optional['TypeTester'] = None) -> 'Table': ... + @classmethod + def from_csv(cls, path: Iterable[str], *, column_types: Optional['TypeTester'] = None) -> 'Table': ... 
class TypeTester: diff --git a/third-party-stubs/jsonrpc/manager.pyi b/third-party-stubs/jsonrpc/manager.pyi index 589ccab6d88..4ddc714662a 100644 --- a/third-party-stubs/jsonrpc/manager.pyi +++ b/third-party-stubs/jsonrpc/manager.pyi @@ -7,7 +7,7 @@ from .jsonrpc import JSONRPCRequest from .jsonrpc1 import JSONRPC10Response from .jsonrpc2 import JSONRPC20BatchRequest, JSONRPC20BatchResponse, JSONRPC20Response from .utils import is_invalid_params -from typing import Any +from typing import Any, List logger: Any @@ -17,3 +17,5 @@ class JSONRPCResponseManager: def handle(cls, request_str: Any, dispatcher: Any): ... @classmethod def handle_request(cls, request: Any, dispatcher: Any): ... + @classmethod + def _get_responses(cls, requests: List[Any], dispatcher: Any): ... diff --git a/third-party-stubs/snowplow_tracker/__init__.pyi b/third-party-stubs/snowplow_tracker/__init__.pyi index 5dff3636134..de00975641b 100644 --- a/third-party-stubs/snowplow_tracker/__init__.pyi +++ b/third-party-stubs/snowplow_tracker/__init__.pyi @@ -1,6 +1,7 @@ import logging from typing import Union, Optional, List, Any, Dict + class Subject: def __init__(self) -> None: ... def set_platform(self, value: Any): ... @@ -20,7 +21,9 @@ logger: logging.Logger class Emitter: + endpoint: str def __init__(self, endpoint: str, protocol: str = ..., port: Optional[int] = ..., method: str = ..., buffer_size: Optional[int] = ..., on_success: Optional[Any] = ..., on_failure: Optional[Any] = ..., byte_limit: Optional[int] = ...) -> None: ... + def is_good_status_code(self, status_code: int) -> bool: ... class Tracker: @@ -31,6 +34,9 @@ class Tracker: encode_base64: bool = ... def __init__(self, emitters: Union[List[Any], Any], subject: Optional[Subject] = ..., namespace: Optional[str] = ..., app_id: Optional[str] = ..., encode_base64: bool = ...) -> None: ... + def set_subject(self, subject: Optional[Subject]): ... + def track_struct_event(self, category: str, action: str, label: Optional[str] = None, property_: Optional[str] = None, value: Optional[float] = None, context: Optional[List[Any]] = None, tstamp: Optional[Any] = None): ... + def flush(self, asynchronous: bool = False): ... class SelfDescribingJson: