Skip to content

Commit

Permalink
Fix mypy checking
Browse files Browse the repository at this point in the history
Make mypy check our nested namespace packages by putting dbt in the mypy_path.
Fix a number of exposed mypy/type checker complaints. The checker mostly
passes now even if you add `--check-untyped-defs`, though there are a couple lingering issues so I'll leave that out of CI
Change the return type of RunOperation a bit - adds a couple fields to appease mypy

Also, bump the mypy version (it catches a few more issues).
  • Loading branch information
Jacob Beck committed Jan 27, 2020
1 parent 4e23e7d commit 9cc7a7a
Show file tree
Hide file tree
Showing 59 changed files with 788 additions and 466 deletions.
5 changes: 3 additions & 2 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import os
from multiprocessing import RLock
# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
from threading import get_ident
from typing import (
Dict, Tuple, Hashable, Optional, ContextManager, List
Expand Down Expand Up @@ -144,7 +145,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
'Opening a new connection, currently in state {}'
.format(conn.state)
)
conn.handle = LazyHandle(type(self))
conn.handle = LazyHandle(self.open)

conn.name = conn_name
return conn
Expand Down
23 changes: 16 additions & 7 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping, Iterator, Union
Mapping, Iterator, Union, Set
)

import agate
Expand Down Expand Up @@ -125,15 +125,18 @@ def add(self, relation: BaseRelation):
key = relation.information_schema_only()
if key not in self:
self[key] = set()
self[key].add(relation.schema.lower())
lowered: Optional[str] = None
if relation.schema is not None:
lowered = relation.schema.lower()
self[key].add(lowered)

def search(self):
for information_schema_name, schemas in self.items():
for schema in schemas:
yield information_schema_name, schema

def schemas_searched(self):
result = set()
result: Set[Tuple[str, str]] = set()
for information_schema_name, schemas in self.items():
result.update(
(information_schema_name.database, schema)
Expand Down Expand Up @@ -907,13 +910,17 @@ def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@available
@classmethod
def convert_type(cls, agate_table, col_idx):
def convert_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
return cls.convert_agate_type(agate_table, col_idx)

@classmethod
def convert_agate_type(cls, agate_table, col_idx):
agate_type = agate_table.column_types[col_idx]
conversions = [
def convert_agate_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
agate_type: Type = agate_table.column_types[col_idx]
conversions: List[Tuple[Type, Callable[..., str]]] = [
(agate.Text, cls.convert_text_type),
(agate.Number, cls.convert_number_type),
(agate.Boolean, cls.convert_boolean_type),
Expand All @@ -925,6 +932,8 @@ def convert_agate_type(cls, agate_table, col_idx):
if isinstance(agate_type, agate_cls):
return func(agate_table, col_idx)

return None

###
# Operations involving the manifest
###
Expand Down
22 changes: 16 additions & 6 deletions core/dbt/adapters/base/meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from functools import wraps
from typing import Callable, Optional, Any, FrozenSet, Dict
from typing import Callable, Optional, Any, FrozenSet, Dict, Set

from dbt.deprecations import warn, renamed_method

Expand Down Expand Up @@ -86,16 +86,26 @@ def parse_list(self, func: Callable) -> Callable:


class AdapterMeta(abc.ABCMeta):
_available_: FrozenSet[str]
_parse_replacements_: Dict[str, Callable]

def __new__(mcls, name, bases, namespace, **kwargs):
cls = super().__new__(mcls, name, bases, namespace, **kwargs)
# mypy does not like the `**kwargs`. But `ABCMeta` itself takes
# `**kwargs` in its argspec here (and passes them to `type.__new__`.
# I'm not sure there is any benefit to it after poking around a bit,
# but having it doesn't hurt on the python side (and omitting it could
# hurt for obscure metaclass reasons, for all I know)
cls = abc.ABCMeta.__new__( # type: ignore
mcls, name, bases, namespace, **kwargs
)

# this is very much inspired by ABCMeta's own implementation

# dict mapping the method name to whether the model name should be
# injected into the arguments. All methods in here are exposed to the
# context.
available = set()
replacements = {}
available: Set[str] = set()
replacements: Dict[str, Any] = {}

# collect base class data first
for base in bases:
Expand All @@ -110,7 +120,7 @@ def __new__(mcls, name, bases, namespace, **kwargs):
if parse_replacement is not None:
replacements[name] = parse_replacement

cls._available_: FrozenSet[str] = frozenset(available)
cls._available_ = frozenset(available)
# should this be a namedtuple so it will be immutable like _available_?
cls._parse_replacements_: Dict[str, Callable] = replacements
cls._parse_replacements_ = replacements
return cls
16 changes: 8 additions & 8 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from threading import local
from typing import Optional, Callable
from typing import Optional, Callable, Dict, Any

from dbt.clients.jinja import QueryStringGenerator

Expand Down Expand Up @@ -56,7 +56,7 @@ def add(self, sql: str) -> str:
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

def set(self, comment: Optional[str]):
if '*/' in comment:
if isinstance(comment, str) and '*/' in comment:
# tell the user "no" so they don't hurt themselves by writing
# garbage
raise RuntimeException(
Expand All @@ -65,7 +65,7 @@ def set(self, comment: Optional[str]):
self.query_comment = comment


QueryStringFunc = Callable[[str, Optional[CompileResultNode]], str]
QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str]


class QueryStringSetter:
Expand All @@ -77,13 +77,14 @@ def __init__(self, config: AdapterRequiredConfig):
self.generator: QueryStringFunc = lambda name, model: ''
# if the comment value was None or the empty string, just skip it
if comment_macro:
assert isinstance(comment_macro, str)
macro = '\n'.join((
'{%- macro query_comment_macro(connection_name, node) -%}',
self._get_comment_macro(),
comment_macro,
'{% endmacro %}'
))
ctx = self._get_context()
self.generator: QueryStringFunc = QueryStringGenerator(macro, ctx)
self.generator = QueryStringGenerator(macro, ctx)
self.comment = _QueryComment(None)
self.reset()

Expand All @@ -105,10 +106,9 @@ def reset(self):
self.set('master', None)

def set(self, name: str, node: Optional[CompileResultNode]):
wrapped: Optional[NodeWrapper] = None
if node is not None:
wrapped = NodeWrapper(node)
else:
wrapped = None
comment_str = self.generator(name, wrapped)
self.comment.set(comment_str)

Expand All @@ -127,5 +127,5 @@ def _get_comment_macro(self):
else:
return super()._get_comment_macro()

def _get_context(self):
def _get_context(self) -> Dict[str, Any]:
return QueryHeaderContext(self.config).to_dict(self.manifest.macros)
9 changes: 6 additions & 3 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,13 @@ def External(cls) -> str:
return str(RelationType.External)

@classproperty
def RelationType(cls) -> Type[RelationType]:
def get_relation_type(cls) -> Type[RelationType]:
return RelationType


Info = TypeVar('Info', bound='InformationSchema')


@dataclass(frozen=True, eq=False, repr=False)
class InformationSchema(BaseRelation):
information_schema_view: Optional[str] = None
Expand Down Expand Up @@ -470,10 +473,10 @@ def get_quote_policy(

@classmethod
def from_relation(
cls: Self,
cls: Type[Info],
relation: BaseRelation,
information_schema_view: Optional[str],
) -> Self:
) -> Info:
include_policy = cls.get_include_policy(
relation, information_schema_view
)
Expand Down
40 changes: 24 additions & 16 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import namedtuple
from copy import deepcopy
from typing import List, Iterable, Optional
from typing import List, Iterable, Optional, Dict, Set, Tuple, Any
import threading

from dbt.logger import CACHE_LOGGER as logger
Expand Down Expand Up @@ -177,20 +177,24 @@ class RelationsCache:
The adapters also hold this lock while filling the cache.
:attr Set[str] schemas: The set of known/cached schemas, all lowercased.
"""
def __init__(self):
self.relations = {}
def __init__(self) -> None:
self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
self.lock = threading.RLock()
self.schemas = set()
self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()

def add_schema(self, database: str, schema: str):
def add_schema(
self, database: Optional[str], schema: Optional[str],
) -> None:
"""Add a schema to the set of known schemas (case-insensitive)
:param database: The database name to add.
:param schema: The schema name to add.
"""
self.schemas.add((_lower(database), _lower(schema)))

def drop_schema(self, database: str, schema: str):
def drop_schema(
self, database: Optional[str], schema: Optional[str],
) -> None:
"""Drop the given schema and remove it from the set of known schemas.
Then remove all its contents (and their dependents, etc) as well.
Expand All @@ -208,21 +212,21 @@ def drop_schema(self, database: str, schema: str):
# handle a drop_schema race by using discard() over remove()
self.schemas.discard(key)

def update_schemas(self, schemas: Iterable[str]):
def update_schemas(self, schemas: Iterable[Tuple[Optional[str], str]]):
"""Add multiple schemas to the set of known schemas (case-insensitive)
:param schemas: An iterable of the schema names to add.
"""
self.schemas.update((_lower(d), _lower(s)) for (d, s) in schemas)
self.schemas.update((_lower(d), s.lower()) for (d, s) in schemas)

def __contains__(self, schema_id):
def __contains__(self, schema_id: Tuple[Optional[str], str]):
"""A schema is 'in' the relations cache if it is in the set of cached
schemas.
:param Tuple[str, str] schema: The db name and schema name to look up.
:param schema_id: The db name and schema name to look up.
"""
db, schema = schema_id
return (_lower(db), _lower(schema)) in self.schemas
return (_lower(db), schema.lower()) in self.schemas

def dump_graph(self):
"""Dump a key-only representation of the schema to a dictionary. Every
Expand All @@ -238,7 +242,7 @@ def dump_graph(self):
for k, v in self.relations.items()
}

def _setdefault(self, relation):
def _setdefault(self, relation: _CachedRelation):
"""Add a relation to the cache, or return it if it already exists.
:param _CachedRelation relation: The relation to set or get.
Expand Down Expand Up @@ -275,6 +279,8 @@ def _add_link(self, referenced_key, dependent_key):
.format(dependent_key)
)

assert dependent is not None # we just raised!

referenced.add_reference(dependent)

def add_link(self, referenced, dependent):
Expand Down Expand Up @@ -305,15 +311,15 @@ def add_link(self, referenced, dependent):
if ref_key not in self.relations:
# Insert a dummy "external" relation.
referenced = referenced.replace(
type=referenced.RelationType.External
type=referenced.External
)
self.add(referenced)

dep_key = _make_key(dependent)
if dep_key not in self.relations:
# Insert a dummy "external" relation.
dependent = dependent.replace(
type=referenced.RelationType.External
type=referenced.External
)
self.add(dependent)
logger.debug(
Expand Down Expand Up @@ -469,7 +475,9 @@ def rename(self, old, new):

lazy_log('after rename: {!s}', self.dump_graph)

def get_relations(self, database, schema):
def get_relations(
self, database: Optional[str], schema: Optional[str]
) -> List[Any]:
"""Case-insensitively yield all relations matching the given schema.
:param str schema: The case-insensitive schema name to list from.
Expand Down Expand Up @@ -498,7 +506,7 @@ def clear(self):
self.schemas.clear()

def _list_relations_in_schema(
self, database: str, schema: str
self, database: Optional[str], schema: Optional[str]
) -> List[_CachedRelation]:
"""Get the relations in a schema. Callers should hold the lock."""
key = (_lower(database), _lower(schema))
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def load_plugin(self, name: str) -> Type[Credentials]:
# and adapter_type entries with the same value, as they're all
# singletons
try:
mod = import_module('.' + name, 'dbt.adapters')
# mypy doesn't think modules have any attributes.
mod: Any = import_module('.' + name, 'dbt.adapters')
except ModuleNotFoundError as exc:
# if we failed to import the target module in particular, inform
# the user about it via a runtiem error
Expand All @@ -56,7 +57,7 @@ def load_plugin(self, name: str) -> Type[Credentials]:
# library. Log the stack trace.
logger.debug('', exc_info=True)
raise
plugin = mod.Plugin # type: AdapterPlugin
plugin: AdapterPlugin = mod.Plugin
plugin_type = plugin.adapter.type()

if plugin_type != name:
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ def list_relations_without_caching(
}
for _database, name, _schema, _type in results:
try:
_type = self.Relation.RelationType(_type)
_type = self.Relation.get_relation_type(_type)
except ValueError:
_type = self.Relation.RelationType.External
_type = self.Relation.External
relations.append(self.Relation.create(
database=_database,
schema=_schema,
Expand Down
1 change: 1 addition & 0 deletions core/dbt/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):

elif self.is_current_end(tag):
self.last_position = tag.end
assert self.current is not None
yield BlockTag(
block_type_name=self.current.block_type_name,
block_name=self.current.block_name,
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/clients/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
logger.debug('Updating existing dependency {}.', directory)
else:
matches = re.match("Cloning into '(.+)'", err.decode('utf-8'))
if matches is None:
raise dbt.exceptions.RuntimeException(
f'Error cloning {repo} - never saw "Cloning into ..." from git'
)
directory = matches.group(1)
logger.debug('Pulling new dependency {}.', directory)
full_path = os.path.join(cwd, directory)
Expand Down
Loading

0 comments on commit 9cc7a7a

Please sign in to comment.