Skip to content

Commit

Permalink
Fix type annotations in lsst.daf.butler.tests
Browse files Browse the repository at this point in the history
  • Loading branch information
timj committed Nov 17, 2022
1 parent b839107 commit 3c73d91
Show file tree
Hide file tree
Showing 12 changed files with 287 additions and 164 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ warn_unreachable = False
[mypy-lsst.daf.butler.registry.queries.expressions.parser.ply.*]
ignore_errors = True

[mypy-lsst.daf.butler.tests.*]
ignore_errors = True

[mypy-lsst.daf.butler.registry.tests.*]
ignore_errors = True

Expand Down
23 changes: 22 additions & 1 deletion python/lsst/daf/butler/core/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,20 @@
import traceback
from contextlib import contextmanager
from logging import Formatter, LogRecord, StreamHandler
from typing import IO, Any, Callable, ClassVar, Dict, Generator, Iterable, Iterator, List, Optional, Union
from typing import (
IO,
Any,
Callable,
ClassVar,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Union,
overload,
)

from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import isplit
Expand Down Expand Up @@ -467,6 +480,14 @@ def __iter__(self) -> Iterator[ButlerLogRecord]: # type: ignore
def __setitem__(self, index: int, value: Record) -> None:
self.__root__[index] = self._validate_record(value)

@overload
def __getitem__(self, index: int) -> ButlerLogRecord:
...

@overload
def __getitem__(self, index: slice) -> "ButlerLogRecords":
...

def __getitem__(self, index: Union[slice, int]) -> "Union[ButlerLogRecords, ButlerLogRecord]":
# Handles slices and returns a new collection in that
# case.
Expand Down
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/core/storageClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def validateParameters(self, parameters: Collection | None = None) -> None:
raise KeyError(f"Parameter{s} '{unknown}' not understood by StorageClass {self.name}")

def filterParameters(
self, parameters: Mapping[str, Any], subset: Collection | None = None
self, parameters: Mapping[str, Any] | None, subset: Collection | None = None
) -> Mapping[str, Any]:
"""Filter out parameters that are not known to this `StorageClass`.
Expand Down
66 changes: 54 additions & 12 deletions python/lsst/daf/butler/tests/_datasetsHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = (
"DatasetTestHelper",
"DatastoreTestHelper",
Expand All @@ -28,25 +30,57 @@
)

import os
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DatasetRef, DatasetType, StorageClass
from lsst.daf.butler.formatters.yaml import YamlFormatter

if TYPE_CHECKING:
from lsst.daf.butler import (
Config,
DataCoordinate,
DatasetId,
Datastore,
Dimension,
DimensionGraph,
Registry,
)


class DatasetTestHelper:
"""Helper methods for Datasets"""

id: int = 0
"""Instance self.id should be reset in setUp."""

def makeDatasetRef(
self, datasetTypeName, dimensions, storageClass, dataId, *, id=None, run=None, conform=True
):
self,
datasetTypeName: str,
dimensions: DimensionGraph | Iterable[str | Dimension],
storageClass: StorageClass | str,
dataId: DataCoordinate,
*,
id: DatasetId | None = None,
run: str | None = None,
conform: bool = True,
) -> DatasetRef:
"""Make a DatasetType and wrap it in a DatasetRef for a test"""
return self._makeDatasetRef(
datasetTypeName, dimensions, storageClass, dataId, id=id, run=run, conform=conform
)

def _makeDatasetRef(
self, datasetTypeName, dimensions, storageClass, dataId, *, id=None, run=None, conform=True
):
self,
datasetTypeName: str,
dimensions: DimensionGraph | Iterable[str | Dimension],
storageClass: StorageClass | str,
dataId: DataCoordinate,
*,
id: DatasetId | None = None,
run: str | None = None,
conform: bool = True,
) -> DatasetRef:
# helper for makeDatasetRef

# Pretend we have a parent if this looks like a composite
Expand All @@ -56,6 +90,7 @@ def _makeDatasetRef(
datasetType = DatasetType(
datasetTypeName, dimensions, storageClass, parentStorageClass=parentStorageClass
)

if id is None:
self.id += 1
id = self.id
Expand All @@ -67,7 +102,13 @@ def _makeDatasetRef(
class DatastoreTestHelper:
"""Helper methods for Datastore tests"""

def setUpDatastoreTests(self, registryClass, configClass):
root: str
id: int
config: Config
datastoreType: type[Datastore]
configFile: str

def setUpDatastoreTests(self, registryClass: type[Registry], configClass: type[Config]) -> None:
"""Shared setUp code for all Datastore tests"""
self.registry = registryClass()

Expand All @@ -81,7 +122,7 @@ def setUpDatastoreTests(self, registryClass, configClass):
if self.root is not None:
self.datastoreType.setConfigRoot(self.root, self.config, self.config.copy())

def makeDatastore(self, sub=None):
def makeDatastore(self, sub: str | None = None) -> Datastore:
"""Make a new Datastore instance of the appropriate type.
Parameters
Expand Down Expand Up @@ -113,10 +154,10 @@ def makeDatastore(self, sub=None):
class BadWriteFormatter(YamlFormatter):
"""A formatter that never works but does leave a file behind."""

def _readFile(self, path, pytype=None):
def _readFile(self, path: str, pytype: type[Any] | None = None) -> Any:
raise NotImplementedError("This formatter can not read anything")

def _writeFile(self, inMemoryDataset):
def _writeFile(self, inMemoryDataset: Any) -> None:
"""Write an empty file and then raise an exception."""
with open(self.fileDescriptor.location.path, "wb"):
pass
Expand All @@ -126,21 +167,22 @@ def _writeFile(self, inMemoryDataset):
class BadNoWriteFormatter(BadWriteFormatter):
"""A formatter that always fails without writing anything."""

def _writeFile(self, inMemoryDataset):
def _writeFile(self, inMemoryDataset: Any) -> None:
raise RuntimeError("Did not writing anything at all")


class MultiDetectorFormatter(YamlFormatter):
def _writeFile(self, inMemoryDataset):
def _writeFile(self, inMemoryDataset: Any) -> None:
raise NotImplementedError("Can not write")

def _fromBytes(self, serializedDataset, pytype=None):
def _fromBytes(self, serializedDataset: Any, pytype: type[Any] | None = None) -> Any:
data = super()._fromBytes(serializedDataset)
if self.dataId is None:
raise RuntimeError("This formatter requires a dataId")
if "detector" not in self.dataId:
if "detector" not in self.dataId: # type: ignore[comparison-overlap]
raise RuntimeError("This formatter requires detector to be present in dataId")
key = f"detector{self.dataId['detector']}"
assert pytype is not None
if key in data:
return pytype(data[key])
raise RuntimeError(f"Could not find '{key}' in data file")
44 changes: 23 additions & 21 deletions python/lsst/daf/butler/tests/_dummyRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("DummyRegistry",)


from typing import Any, Iterable, Iterator, Optional, Type
from collections.abc import Iterable, Iterator
from typing import Any

import sqlalchemy
from lsst.daf.butler import DatasetRef, DimensionUniverse, ddl
from lsst.daf.butler import DimensionUniverse, ddl
from lsst.daf.butler.registry.bridge.ephemeral import EphemeralDatastoreRegistryBridge
from lsst.daf.butler.registry.interfaces import (
Database,
DatasetIdRef,
DatasetRecordStorageManager,
DatastoreRegistryBridge,
DatastoreRegistryBridgeManager,
Expand All @@ -40,12 +42,12 @@


class DummyOpaqueTableStorage(OpaqueTableStorage):
def __init__(self, name: str, spec: ddl.TableSpec):
def __init__(self, name: str, spec: ddl.TableSpec) -> None:
super().__init__(name=name)
self._rows = []
self._rows: list[dict] = []
self._spec = spec

def insert(self, *data: dict):
def insert(self, *data: dict) -> None:
# Docstring inherited from OpaqueTableStorage.
uniqueConstraints = list(self._spec.unique)
uniqueConstraints.append(tuple(field.name for field in self._spec.fields if field.primaryKey))
Expand Down Expand Up @@ -82,7 +84,7 @@ def fetch(self, **where: Any) -> Iterator[dict]:
else:
yield d

def delete(self, columns: Iterable[str], *rows: dict):
def delete(self, columns: Iterable[str], *rows: dict) -> None:
# Docstring inherited from OpaqueTableStorage.
kept_rows = []
for table_row in self._rows:
Expand All @@ -95,29 +97,29 @@ def delete(self, columns: Iterable[str], *rows: dict):


class DummyOpaqueTableStorageManager(OpaqueTableStorageManager):
def __init__(self):
self._storages = {}
def __init__(self) -> None:
self._storages: dict[str, DummyOpaqueTableStorage] = {}

@classmethod
def initialize(cls, db: Database, context: StaticTablesContext) -> OpaqueTableStorageManager:
# Docstring inherited from OpaqueTableStorageManager.
# Not used, but needed to satisfy ABC requirement.
return cls()

def get(self, name: str) -> Optional[OpaqueTableStorage]:
def get(self, name: str) -> OpaqueTableStorage | None:
# Docstring inherited from OpaqueTableStorageManager.
return self._storage.get(name)
return self._storages.get(name)

def register(self, name: str, spec: ddl.TableSpec) -> OpaqueTableStorage:
# Docstring inherited from OpaqueTableStorageManager.
return self._storages.setdefault(name, DummyOpaqueTableStorage(name, spec))

@classmethod
def currentVersion(cls) -> Optional[VersionTuple]:
def currentVersion(cls) -> VersionTuple | None:
# Docstring inherited from VersionedExtension.
return None

def schemaDigest(self) -> Optional[str]:
def schemaDigest(self) -> str | None:
# Docstring inherited from VersionedExtension.
return None

Expand All @@ -127,7 +129,7 @@ def __init__(
self, opaque: OpaqueTableStorageManager, universe: DimensionUniverse, datasetIdColumnType: type
):
super().__init__(opaque=opaque, universe=universe, datasetIdColumnType=datasetIdColumnType)
self._bridges = {}
self._bridges: dict[str, EphemeralDatastoreRegistryBridge] = {}

@classmethod
def initialize(
Expand All @@ -136,46 +138,46 @@ def initialize(
context: StaticTablesContext,
*,
opaque: OpaqueTableStorageManager,
datasets: Type[DatasetRecordStorageManager],
datasets: type[DatasetRecordStorageManager],
universe: DimensionUniverse,
) -> DatastoreRegistryBridgeManager:
# Docstring inherited from DatastoreRegistryBridgeManager
# Not used, but needed to satisfy ABC requirement.
return cls(opaque=opaque, universe=universe, datasetIdColumnType=datasets.getIdColumnType())

def refresh(self):
def refresh(self) -> None:
# Docstring inherited from DatastoreRegistryBridgeManager
pass

def register(self, name: str, *, ephemeral: bool = False) -> DatastoreRegistryBridge:
# Docstring inherited from DatastoreRegistryBridgeManager
return self._bridges.setdefault(name, EphemeralDatastoreRegistryBridge(name))

def findDatastores(self, ref: DatasetRef) -> Iterable[str]:
def findDatastores(self, ref: DatasetIdRef) -> Iterable[str]:
# Docstring inherited from DatastoreRegistryBridgeManager
for name, bridge in self._bridges.items():
if ref in bridge:
yield name

@classmethod
def currentVersion(cls) -> Optional[VersionTuple]:
def currentVersion(cls) -> VersionTuple | None:
# Docstring inherited from VersionedExtension.
return None

def schemaDigest(self) -> Optional[str]:
def schemaDigest(self) -> str | None:
# Docstring inherited from VersionedExtension.
return None


class DummyRegistry:
"""Dummy Registry, for Datastore test purposes."""

def __init__(self):
def __init__(self) -> None:
self._opaque = DummyOpaqueTableStorageManager()
self.dimensions = DimensionUniverse()
self._datastoreBridges = DummyDatastoreRegistryBridgeManager(
self._opaque, self.dimensions, sqlalchemy.BigInteger
)

def getDatastoreBridgeManager(self):
def getDatastoreBridgeManager(self) -> DatastoreRegistryBridgeManager:
return self._datastoreBridges

0 comments on commit 3c73d91

Please sign in to comment.