Skip to content

Commit

Permalink
Merge pull request #276 from lsst/tickets/DM-24780
Browse files Browse the repository at this point in the history
DM-24780: add initial mypy support.
  • Loading branch information
TallJimbo committed May 8, 2020
2 parents ed29410 + 11cb3eb commit 4a27313
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 43 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ bin/*
.pytest_cache/
.mypy_cache/
.idea/
.vscode/
.vscode/
mypy.log
5 changes: 4 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ matrix:
- python: '3.7'
install:
- pip install flake8
script: flake8
- pip install mypy
script:
- flake8
- cd python && mypy -p lsst.daf.butler
6 changes: 6 additions & 0 deletions python/SConscript
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -*- python -*-
from lsst.sconsUtils import state
mypy = state.env.Command("../mypy.log", "lsst/daf/butler",
"cd python && mypy -p lsst.daf.butler 2>&1 | tee -a ../mypy.log")

state.env.Alias("mypy", mypy)
15 changes: 8 additions & 7 deletions python/lsst/daf/butler/registry/interfaces/_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .._collectionType import CollectionType

if TYPE_CHECKING:
from .database import Database, StaticTablesContext
from ._database import Database, StaticTablesContext


class MissingCollectionError(Exception):
Expand Down Expand Up @@ -94,7 +94,8 @@ class RunRecord(CollectionRecord):
"""

@abstractmethod
def update(self, host: Optional[str] = None, timespan: Optional[Timespan[astropy.time.Time]] = None):
def update(self, host: Optional[str] = None,
timespan: Optional[Timespan[astropy.time.Time]] = None) -> None:
"""Update the database record for this run with new execution
information.
Expand Down Expand Up @@ -156,7 +157,7 @@ def children(self) -> CollectionSearch:
"""
return self._children

def update(self, manager: CollectionManager, children: CollectionSearch):
def update(self, manager: CollectionManager, children: CollectionSearch) -> None:
"""Redefine this chain to search the given child collections.
This method should be used by all external code to set children. It
Expand Down Expand Up @@ -184,7 +185,7 @@ def update(self, manager: CollectionManager, children: CollectionSearch):
self._update(manager, children)
self._children = children

def refresh(self, manager: CollectionManager):
def refresh(self, manager: CollectionManager) -> None:
"""Load children from the database, using the given manager to resolve
collection primary key values into records.
Expand All @@ -203,7 +204,7 @@ def refresh(self, manager: CollectionManager):
self._children = self._load(manager)

@abstractmethod
def _update(self, manager: CollectionManager, children: CollectionSearch):
def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
"""Protected implementation hook for setting the `children` property.
This method should be implemented by subclasses to update the database
Expand Down Expand Up @@ -375,7 +376,7 @@ def getRunForeignKeyName(cls, prefix: str = "run") -> str:
raise NotImplementedError()

@abstractmethod
def refresh(self):
def refresh(self) -> None:
"""Ensure all other operations on this manager are aware of any
collections that may have been registered by other clients since it
was initialized or last refreshed.
Expand Down Expand Up @@ -420,7 +421,7 @@ def register(self, name: str, type: CollectionType) -> CollectionRecord:
raise NotImplementedError()

@abstractmethod
def remove(self, name: str):
def remove(self, name: str) -> None:
"""Completely remove a collection.
Any existing `CollectionRecord` objects that correspond to the removed
Expand Down
46 changes: 27 additions & 19 deletions python/lsst/daf/butler/registry/interfaces/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
)
import warnings
Expand All @@ -47,7 +49,7 @@
from .._exceptions import ConflictingDefinitionError


def _checkExistingTableDefinition(name: str, spec: ddl.TableSpec, inspection: Dict[str, Any]):
def _checkExistingTableDefinition(name: str, spec: ddl.TableSpec, inspection: List[Dict[str, Any]]) -> None:
"""Test that the definition of a table in a `ddl.TableSpec` and from
database introspection are consistent.
Expand Down Expand Up @@ -95,7 +97,7 @@ class StaticTablesContext:

def __init__(self, db: Database):
self._db = db
self._foreignKeys = []
self._foreignKeys: List[Tuple[sqlalchemy.schema.Table, sqlalchemy.schema.ForeignKeyConstraint]] = []
self._inspector = sqlalchemy.engine.reflection.Inspector(self._db._connection)
self._tableNames = frozenset(self._inspector.get_table_names(schema=self._db.namespace))

Expand Down Expand Up @@ -136,7 +138,8 @@ def addTableTuple(self, specs: Tuple[ddl.TableSpec, ...]) -> Tuple[sqlalchemy.sc
is just a factory for `type` objects, not an actual type itself,
we cannot represent this with type annotations.
"""
return specs._make(self.addTable(name, spec) for name, spec in zip(specs._fields, specs))
return specs._make(self.addTable(name, spec) # type: ignore
for name, spec in zip(specs._fields, specs)) # type: ignore


class Database(ABC):
Expand Down Expand Up @@ -173,8 +176,7 @@ class Database(ABC):
`Database` itself has several underscore-prefixed attributes:
- ``_cs``: SQLAlchemy objects representing the connection and transaction
state.
- ``_connection``: SQLAlchemy object representing the connection.
- ``_metadata``: the `sqlalchemy.schema.MetaData` object representing
the tables and other schema entities.
Expand All @@ -188,7 +190,7 @@ def __init__(self, *, origin: int, connection: sqlalchemy.engine.Connection,
self.origin = origin
self.namespace = namespace
self._connection = connection
self._metadata = None
self._metadata: Optional[sqlalchemy.schema.MetaData] = None

@classmethod
def makeDefaultUri(cls, root: str) -> Optional[str]:
Expand Down Expand Up @@ -299,7 +301,7 @@ def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int
raise NotImplementedError()

@contextmanager
def transaction(self, *, interrupting: bool = False) -> None:
def transaction(self, *, interrupting: bool = False) -> Iterator:
"""Return a context manager that represents a transaction.
Parameters
Expand All @@ -326,7 +328,7 @@ def transaction(self, *, interrupting: bool = False) -> None:
raise

@contextmanager
def declareStaticTables(self, *, create: bool) -> StaticTablesContext:
def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]:
"""Return a context manager in which the database's static DDL schema
can be declared.
Expand Down Expand Up @@ -466,7 +468,7 @@ def _mangleTableName(self, name: str) -> str:
return name

def _convertFieldSpec(self, table: str, spec: ddl.FieldSpec, metadata: sqlalchemy.MetaData,
**kwds) -> sqlalchemy.schema.Column:
**kwds: Any) -> sqlalchemy.schema.Column:
"""Convert a `FieldSpec` to a `sqlalchemy.schema.Column`.
Parameters
Expand Down Expand Up @@ -501,7 +503,7 @@ def _convertFieldSpec(self, table: str, spec: ddl.FieldSpec, metadata: sqlalchem
comment=spec.doc, **kwds)

def _convertForeignKeySpec(self, table: str, spec: ddl.ForeignKeySpec, metadata: sqlalchemy.MetaData,
**kwds) -> sqlalchemy.schema.ForeignKeyConstraint:
**kwds: Any) -> sqlalchemy.schema.ForeignKeyConstraint:
"""Convert a `ForeignKeySpec` to a
`sqlalchemy.schema.ForeignKeyConstraint`.
Expand Down Expand Up @@ -537,7 +539,7 @@ def _convertForeignKeySpec(self, table: str, spec: ddl.ForeignKeySpec, metadata:
)

def _convertTableSpec(self, name: str, spec: ddl.TableSpec, metadata: sqlalchemy.MetaData,
**kwds) -> sqlalchemy.schema.Table:
**kwds: Any) -> sqlalchemy.schema.Table:
"""Convert a `TableSpec` to a `sqlalchemy.schema.Table`.
Parameters
Expand Down Expand Up @@ -702,7 +704,7 @@ def sync(self, table: sqlalchemy.schema.Table, *,
compared: Optional[Dict[str, Any]] = None,
extra: Optional[Dict[str, Any]] = None,
returning: Optional[Sequence[str]] = None,
) -> Tuple[Optional[Dict[str, Any], bool]]:
) -> Tuple[Optional[Dict[str, Any]], bool]:
"""Insert into a table as necessary to ensure database contains
values equivalent to the given ones.
Expand Down Expand Up @@ -749,7 +751,7 @@ def sync(self, table: sqlalchemy.schema.Table, *,
already exist.
"""

def check():
def check() -> Tuple[int, Optional[List[str]], Optional[List]]:
"""Query for a row that matches the ``key`` argument, and compare
to what was given by the caller.
Expand All @@ -768,7 +770,7 @@ def check():
Results in the database that correspond to the columns given
in ``returning``, or `None` if ``returning is None``.
"""
toSelect = set()
toSelect: Set[str] = set()
if compared is not None:
toSelect.update(compared.keys())
if returning is not None:
Expand All @@ -788,7 +790,7 @@ def check():
existing = fetched[0]
if compared is not None:

def safeNotEqual(a, b):
def safeNotEqual(a: Any, b: Any) -> bool:
if isinstance(a, astropy.time.Time):
return not time_utils.times_equal(a, b)
return a != b
Expand All @@ -799,7 +801,7 @@ def safeNotEqual(a, b):
else:
inconsistencies = []
if returning is not None:
toReturn = [existing[k] for k in returning]
toReturn: Optional[list] = [existing[k] for k in returning]
else:
toReturn = None
return 1, inconsistencies, toReturn
Expand Down Expand Up @@ -873,7 +875,11 @@ def safeNotEqual(a, b):
elif bad:
raise DatabaseConflictError(f"Conflict in sync on column(s) {bad}.")
inserted = False
return {k: v for k, v in zip(returning, result)} if returning is not None else None, inserted
if returning is None:
return None, inserted
else:
assert result is not None
return {k: v for k, v in zip(returning, result)}, inserted

def insert(self, table: sqlalchemy.schema.Table, *rows: dict, returnIds: bool = False,
) -> Optional[List[int]]:
Expand Down Expand Up @@ -919,12 +925,13 @@ def insert(self, table: sqlalchemy.schema.Table, *rows: dict, returnIds: bool =
raise ReadOnlyDatabaseError(f"Attempt to insert into read-only database '{self}'.")
if not returnIds:
self._connection.execute(table.insert(), *rows)
return None
else:
sql = table.insert()
return [self._connection.execute(sql, row).inserted_primary_key[0] for row in rows]

@abstractmethod
def replace(self, table: sqlalchemy.schema.Table, *rows: dict):
def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
"""Insert one or more rows into a table, replacing any existing rows
for which insertion of a new row would violate the primary key
constraint.
Expand Down Expand Up @@ -1043,7 +1050,8 @@ def update(self, table: sqlalchemy.schema.Table, where: Dict[str, str], *rows: d
)
return self._connection.execute(sql, *rows).rowcount

def query(self, sql: sqlalchemy.sql.FromClause, *args, **kwds) -> sqlalchemy.engine.ResultProxy:
def query(self, sql: sqlalchemy.sql.FromClause,
*args: Any, **kwds: Any) -> sqlalchemy.engine.ResultProxy:
"""Run a SELECT query against the database.
Parameters
Expand Down
13 changes: 7 additions & 6 deletions python/lsst/daf/butler/registry/interfaces/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Iterable,
Iterator,
Expand Down Expand Up @@ -115,7 +116,7 @@ def find(self, collection: CollectionRecord, dataId: DataCoordinate) -> Optional
raise NotImplementedError()

@abstractmethod
def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]):
def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
"""Associate one or more datasets with a collection.
Parameters
Expand Down Expand Up @@ -144,7 +145,7 @@ def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]
raise NotImplementedError()

@abstractmethod
def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]):
def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
"""Remove one or more datasets from a collection.
Parameters
Expand Down Expand Up @@ -245,9 +246,9 @@ def initialize(cls, db: Database, context: StaticTablesContext, *, collections:

@classmethod
@abstractmethod
def addDatasetForeignKey(cls, tableSpec: ddl.TableSpec, *, name: str = "dataset",
constraint: bool = True, onDelete: Optional[str] = None,
**kwargs) -> ddl.FieldSpec:
def addDatasetForeignKey(cls, tableSpec: ddl.TableSpec, *,
name: str = "dataset", constraint: bool = True, onDelete: Optional[str] = None,
**kwargs: Any) -> ddl.FieldSpec:
"""Add a foreign key (field and constraint) referencing the dataset
table.
Expand Down Expand Up @@ -279,7 +280,7 @@ def addDatasetForeignKey(cls, tableSpec: ddl.TableSpec, *, name: str = "dataset"
raise NotImplementedError()

@abstractmethod
def refresh(self, *, universe: DimensionUniverse):
def refresh(self, *, universe: DimensionUniverse) -> None:
"""Ensure all other operations on this manager are aware of any
dataset types that may have been registered by other clients since
it was initialized or last refreshed.
Expand Down
10 changes: 5 additions & 5 deletions python/lsst/daf/butler/registry/interfaces/_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def element(self) -> DimensionElement:
raise NotImplementedError()

@abstractmethod
def clearCaches(self):
def clearCaches(self) -> None:
"""Clear any in-memory caches held by the storage instance.
This is called by `Registry` when transactions are rolled back, to
Expand All @@ -148,7 +148,7 @@ def join(
builder: QueryBuilder, *,
regions: Optional[NamedKeyDict[DimensionElement, sqlalchemy.sql.ColumnElement]] = None,
timespans: Optional[NamedKeyDict[DimensionElement, Timespan[sqlalchemy.sql.ColumnElement]]] = None,
):
) -> sqlalchemy.sql.FromClause:
"""Add the dimension element's logical table to a query under
construction.
Expand Down Expand Up @@ -187,7 +187,7 @@ def join(
raise NotImplementedError()

@abstractmethod
def insert(self, *records: DimensionRecord):
def insert(self, *records: DimensionRecord) -> None:
"""Insert one or more records into storage.
Parameters
Expand Down Expand Up @@ -313,7 +313,7 @@ def initialize(cls, db: Database, context: StaticTablesContext, *,
raise NotImplementedError()

@abstractmethod
def refresh(self):
def refresh(self) -> None:
"""Ensure all other operations on this manager are aware of any
dataset types that may have been registered by other clients since
it was initialized or last refreshed.
Expand Down Expand Up @@ -379,7 +379,7 @@ def register(self, element: DimensionElement) -> DimensionRecordStorage:
raise NotImplementedError()

@abstractmethod
def clearCaches(self):
def clearCaches(self) -> None:
"""Clear any in-memory caches held by nested `DimensionRecordStorage`
instances.
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/interfaces/_opaque.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, name: str):
self.name = name

@abstractmethod
def insert(self, *data: dict):
def insert(self, *data: dict) -> None:
"""Insert records into the table
Parameters
Expand Down Expand Up @@ -80,7 +80,7 @@ def fetch(self, **where: Any) -> Iterator[dict]:
raise NotImplementedError()

@abstractmethod
def delete(self, **where: Any):
def delete(self, **where: Any) -> None:
"""Remove records from an opaque table.
Parameters
Expand Down

0 comments on commit 4a27313

Please sign in to comment.