Skip to content

Commit

Permalink
Mypy cleanup (#1562)
Browse files Browse the repository at this point in the history
* Update mypy config to Py3.10

* Fix typing errors in model/lineage.py

* Miscellaneous type cleanup.

* Miscellaneous type cleanup - tweaks

* Cleanup type hints in index/abstract.py

* Cleanup type hints in memory index driver.

* Remove incomplete type hints in null index driver.

* Minor tweaks.

* Tests fixed.

* Fix type hints in model/model.py

* Fix type hints in Postgres driver.

* Fix type hints in Postgis driver (drivers tree).

* More typecheck cleanup and fix some field ambiguities.

* Less than 60 mypy errors left and tests all passing.

* Postgis driver done - 40 errors to go.

* Fix test regressions.

* Fix mypy issue in scripts directory.

* Fix or suppress mypy issues in virtual products.

* Fix or suppress mypy errors in api core.

* Fix or suppress mypy errors in model properties

* Fix or suppress mypy errors in storage layer.

* Run MyPy in GH action

* Lintage.

* Install type stubs for mypy checks.

* Oops type typo

* Minor cleanup.

* Update whats_new.rst.
  • Loading branch information
SpacemanPaul committed Mar 14, 2024
1 parent 5b0d0c7 commit a2f71d0
Show file tree
Hide file tree
Showing 48 changed files with 580 additions and 465 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,30 @@ jobs:
sudo apt-get remove python3-openssl
pip install --upgrade -e '.[test]'
pylint -j 2 --reports no datacube
# CI job: run MyPy static type checking over the datacube package.
mypy:
runs-on: ubuntu-latest
strategy:
matrix:
# Only checked against Python 3.10 for now.
python-version: ["3.10"]
name: MyPy
steps:
# Full-depth checkout — presumably so the package version can be derived
# from git history during the editable install; TODO confirm.
- name: checkout git
uses: actions/checkout@v4
with:
fetch-depth: 0
# Conda-provisioned Python at the matrix version.
- name: Setup conda
uses: s-weigand/setup-conda@v1
with:
update-conda: true
python-version: ${{ matrix.python-version }}
conda-channels: anaconda, conda-forge
# Install the package with its type-stub extras, then type-check.
- name: run mypy
run: |
sudo apt-get remove python3-openssl
pip install --upgrade -e '.[types]'
mypy datacube
pycodestyle:
Expand Down
2 changes: 1 addition & 1 deletion datacube/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ def _make_dask_array(chunked_srcs,
# B BR
empties: dict[tuple[int, int], str] = {}

def _mk_empty(shape: tuple[int, ...]) -> str:
def _mk_empty(shape: tuple[int, int]) -> str:
name = empties.get(shape, None)
if name is not None:
return name
Expand Down
9 changes: 5 additions & 4 deletions datacube/cfg/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings

from os import PathLike
from typing import Any, TypeAlias, Union
from typing import Any, TypeAlias, Union, cast

from .cfg import find_config, parse_text
from .exceptions import ConfigException
Expand Down Expand Up @@ -111,9 +111,10 @@ def __init__(
text = find_config(paths)

self.raw_text = text
self.raw_config = raw_dict
if not self.raw_config:
self.raw_config = parse_text(self.raw_text)
if raw_dict is not None:
self.raw_config = raw_dict
else:
self.raw_config = parse_text(cast(str, self.raw_text))

self._aliases: dict[str, str] = {}
self.known_environments: dict[str, ODCEnvironment] = {
Expand Down
11 changes: 6 additions & 5 deletions datacube/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,24 @@ def find_config(paths_in: None | str | PathLike | list[str | PathLike]) -> str:
:return: The contents of the first readable file found.
"""
using_default_paths: bool = False
paths: list[str | PathLike] = []
if paths_in is None:
if os.environ.get("ODC_CONFIG_PATH"):
paths: list[str | PathLike] = os.environ["ODC_CONFIG_PATH"].split(':')
paths.extend(os.environ["ODC_CONFIG_PATH"].split(':'))
elif os.environ.get("DATACUBE_CONFIG_PATH"):
warnings.warn(
"Datacube config path being determined by legacy $DATACUBE_CONFIG_PATH environment variable. "
"This environment variable is deprecated and the behaviour of it has changed somewhat since datacube "
"1.8.x. Please refer to the documentation for details and switch to $ODC_CONFIG_PATH"
)
paths = os.environ["DATACUBE_CONFIG_PATH"].split(':')
paths.extend(os.environ["DATACUBE_CONFIG_PATH"].split(':'))
else:
paths: list[str | PathLike] = _DEFAULT_CONFIG_SEARCH_PATH
paths.extend(_DEFAULT_CONFIG_SEARCH_PATH)
using_default_paths = True
elif isinstance(paths_in, str) or isinstance(paths_in, PathLike):
paths = [paths_in]
paths.append(paths_in)
else:
paths = paths_in
paths.extend(paths_in)

for path in paths:
try:
Expand Down
3 changes: 2 additions & 1 deletion datacube/cfg/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
try:
import pwd

_DEFAULT_DB_USER = pwd.getpwuid(os.geteuid()).pw_name # type: Optional[str]
_DEFAULT_DB_USER: str | None = pwd.getpwuid(os.geteuid()).pw_name
except (ImportError, KeyError):
# No default on Windows and some other systems
_DEFAULT_DB_USER = None
Expand Down Expand Up @@ -140,6 +140,7 @@ def handle_dependent_options(self, value: Any) -> None:
# Get driver-specific config options
from datacube.drivers.indexes import index_driver_by_name
driver = index_driver_by_name(value)
assert driver is not None
for option in driver.get_config_option_handlers(self.env):
self.env._option_handlers.append(option)

Expand Down
1 change: 1 addition & 0 deletions datacube/drivers/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@ def formats(self) -> List[str]:
def supports(self, protocol: str, fmt: str) -> bool:
... # pragma: no cover

@abstractmethod
def new_instance(self, cfg: dict) -> ReaderDriver:
    """Create a new ``ReaderDriver`` from ``cfg`` (presumably driver-specific config — confirm against implementations)."""
    ... # pragma: no cover
36 changes: 19 additions & 17 deletions datacube/drivers/postgis/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from sqlalchemy import select, text, and_, or_, func
from sqlalchemy.dialects.postgresql import INTERVAL
from sqlalchemy.exc import IntegrityError
from typing import Iterable, Sequence, Optional, Set
from sqlalchemy.engine import Row

from typing import Iterable, Sequence, Optional, Set, Any
from typing import cast as type_cast

from datacube.index.fields import OrExpression
from datacube.model import Range
Expand All @@ -33,20 +36,20 @@
from datacube.index.abstract import DSID
from datacube.model.lineage import LineageRelation, LineageDirection
from . import _core
from ._fields import parse_fields, Expression, PgField, PgExpression # noqa: F401
from ._fields import parse_fields, Expression, PgField, PgExpression, DateRangeDocField # noqa: F401
from ._fields import NativeField, DateDocField, SimpleDocField, UnindexableValue
from ._schema import MetadataType, Product, \
Dataset, DatasetLineage, DatasetLocation, SelectedDatasetLocation, \
search_field_index_map, search_field_indexes, DatasetHome
from ._spatial import geom_alchemy, generate_dataset_spatial_values, extract_geometry_from_eo3_projection
from .sql import escape_pg_identifier

from ...utils.changes import Offset

_LOG = logging.getLogger(__name__)


# Make a function because it's broken
def _dataset_select_fields():
def _dataset_select_fields() -> tuple:
return (
Dataset,
# All active URIs, from newest to oldest
Expand All @@ -66,7 +69,7 @@ def _dataset_select_fields():
)


def _dataset_bulk_select_fields():
def _dataset_bulk_select_fields() -> tuple:
return (
Dataset.product_ref,
Dataset.metadata_doc,
Expand All @@ -87,7 +90,7 @@ def _dataset_bulk_select_fields():
)


def get_native_fields():
def get_native_fields() -> dict[str, NativeField]:
# Native fields (hard-coded into the schema)
fields = {
'id': NativeField(
Expand Down Expand Up @@ -145,7 +148,7 @@ def get_native_fields():
return fields


def mk_simple_offset_field(field_name, description, offset):
def mk_simple_offset_field(field_name: str, description: str, offset: Offset) -> SimpleDocField:
return SimpleDocField(
name=field_name, description=description,
alchemy_column=Dataset.metadata_doc,
Expand Down Expand Up @@ -754,7 +757,7 @@ def search_unique_datasets(self, expressions, select_fields=None, limit=None, ar

return self._connection.execute(select_query)

def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]) -> Iterable[tuple]:
def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]) -> Iterable[Row]:
# TODO
if "time" in [f.name for f in match_fields]:
return self.get_duplicates_with_time(match_fields, expressions)
Expand All @@ -780,11 +783,11 @@ def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[

def get_duplicates_with_time(
self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]
) -> Iterable[tuple]:
) -> Iterable[Row]:
fields = []
for f in match_fields:
if f.name == "time":
time_field = f.expression_with_leniency
time_field = type_cast(DateRangeDocField, f).expression_with_leniency
else:
fields.append(f.alchemy_expression)

Expand Down Expand Up @@ -820,7 +823,7 @@ def get_duplicates_with_time(
*fields,
text("(lower(time_intersect) at time zone 'UTC', upper(time_intersect) at time zone 'UTC') as time")
).select_from(
time_overlap
time_overlap # type: ignore[arg-type]
).group_by(
*fields, text("time_intersect")
).having(
Expand Down Expand Up @@ -1131,7 +1134,7 @@ def update_metadata_type(self, name, definition):

def _get_active_field_names(fields, metadata_doc):
for field in fields.values():
if hasattr(field, 'extract'):
if field.can_extract:
try:
value = field.extract(metadata_doc)
if value is not None:
Expand Down Expand Up @@ -1279,13 +1282,12 @@ def create_user(self, username, password, role, description=None):
sql = text('comment on role {username} is :description'.format(username=username))
self._connection.execute(sql, {"description": description})

def drop_users(self, users: Iterable[str]) -> None:
    """
    Drop the given database roles.

    :param users: Iterable of role names to drop.
    """
    for username in users:
        # Role names cannot be bound parameters, so escape each identifier
        # explicitly before interpolating it into the DROP ROLE statement.
        sql = text('drop role {username}'.format(username=escape_pg_identifier(self._connection, username)))
        self._connection.execute(sql)

def grant_role(self, role, users):
# type: (str, Iterable[str]) -> None
def grant_role(self, role: str, users: Iterable[str]) -> None:
"""
Grant a role to a user.
"""
Expand All @@ -1297,7 +1299,7 @@ def grant_role(self, role, users):

_core.grant_role(self._connection, pg_role, users)

def insert_home(self, home, ids, allow_updates):
def insert_home(self, home: str, ids: Iterable[uuid.UUID], allow_updates: bool) -> int:
"""
Set home for multiple IDs (but one home value)
Expand Down Expand Up @@ -1379,7 +1381,7 @@ def write_relations(self, relations: Iterable[LineageRelation], allow_updates: b
:return: Count of database rows affected
"""
if allow_updates:
by_classifier = {}
by_classifier: dict[str, Any] = {}
for rel in relations:
db_repr = {
"derived_dataset_ref": rel.derived_id,
Expand Down
42 changes: 24 additions & 18 deletions datacube/drivers/postgis/_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from . import _api
from . import _core
from ._spatial import ensure_spindex, spindexes, spindex_for_crs, drop_spindex
from ._spatial import ensure_spindex, spindexes, spindex_for_crs, drop_spindex, crs_to_epsg
from ._schema import SpatialIndex
from ...cfg import ODCEnvironment, psql_url_from_config

Expand All @@ -38,7 +38,7 @@
_LOG = logging.getLogger(__name__)


class PostGisDb(object):
class PostGisDb:
"""
A thin database access api.
Expand All @@ -55,7 +55,7 @@ class PostGisDb(object):

driver_name = 'postgis' # Mostly to support parametised tests

def __init__(self, engine):
def __init__(self, engine: Engine):
# We don't recommend using this constructor directly as it may change.
# Use static methods PostGisDb.create() or PostGisDb.from_config()
self._engine = engine
Expand All @@ -65,13 +65,16 @@ def __init__(self, engine):
def from_config(cls,
config_env: ODCEnvironment,
application_name: str | None = None,
validate_connection: bool = True):
validate_connection: bool = True) -> "PostGisDb":
app_name = cls._expand_app_name(application_name)

return PostGisDb.create(config_env, application_name=app_name, validate=validate_connection)

@classmethod
def create(cls, config_env: ODCEnvironment, application_name: str | None = None, validate: bool = True):
def create(cls,
config_env: ODCEnvironment,
application_name: str | None = None,
validate: bool = True) -> "PostGisDb":
url = psql_url_from_config(config_env)
kwargs = {
"application_name": application_name,
Expand All @@ -96,7 +99,7 @@ def create(cls, config_env: ODCEnvironment, application_name: str | None = None,
return PostGisDb(engine)

@staticmethod
def _create_engine(url, application_name=None, iam_rds_auth=False, iam_rds_timeout=600, pool_timeout=60):
def _create_engine(url, application_name=None, iam_rds_auth=False, iam_rds_timeout=600, pool_timeout=60) -> Engine:
engine = create_engine(
url,
echo=False,
Expand All @@ -123,7 +126,7 @@ def _create_engine(url, application_name=None, iam_rds_auth=False, iam_rds_timeo
def url(self) -> EngineUrl:
return self._engine.url

def close(self):
def close(self) -> None:
"""
Close any idle connections in the pool.
Expand Down Expand Up @@ -165,7 +168,7 @@ def _expand_app_name(cls, application_name):
_LOG.warning('Application name is too long: Truncating to %s chars', (64 - len(_LIB_ID) - 1))
return full_name[-64:]

def init(self, with_permissions=True):
def init(self, with_permissions: bool = True) -> bool:
"""
Init a new database (if not already set up).
Expand All @@ -177,13 +180,14 @@ def init(self, with_permissions=True):

return is_new

def _refresh_spindexes(self) -> None:
    """Reload the cached mapping of spatial indexes from the database engine."""
    self._spindexes = spindexes(self._engine)

@property
def spindexes(self) -> Mapping[int, Type[SpatialIndex]]:
    """Mapping of EPSG code to spatial index model, loaded lazily on first access."""
    cached = self._spindexes
    if cached is None:
        self._refresh_spindexes()
        cached = self._spindexes
        assert cached is not None  # refresh populates the cache; narrows Optional for mypy
    return cached

def create_spatial_index(self, crs: CRS) -> Optional[Type[SpatialIndex]]:
Expand All @@ -193,14 +197,16 @@ def create_spatial_index(self, crs: CRS) -> Optional[Type[SpatialIndex]]:
:param crs:
:return:
"""
spidx = self.spindexes.get(crs.epsg)
if spidx is None:
spidx = spindex_for_crs(crs)
try:
spidx = self.spindexes.get(crs_to_epsg(crs))
if spidx is None:
_LOG.warning("Could not dynamically model an index for CRS %s", crs._str)
return None
ensure_spindex(self._engine, spidx)
self._refresh_spindexes()
spidx = spindex_for_crs(crs)
assert spidx is not None # for type checker
ensure_spindex(self._engine, spidx)
self._refresh_spindexes()
except ValueError:
_LOG.warning("Could not dynamically model an index for CRS %s", crs._str)
return None
return spidx

def drop_spatial_index(self, crs: CRS) -> bool:
Expand All @@ -210,15 +216,15 @@ def drop_spatial_index(self, crs: CRS) -> bool:
:param crs:
:return:
"""
spidx = self.spindexes.get(crs.epsg)
spidx = self.spindexes.get(crs_to_epsg(crs))
if spidx is None:
return False
result = drop_spindex(self._engine, spidx)
self._refresh_spindexes()
return result

def spatial_index(self, crs: CRS) -> Optional[Type[SpatialIndex]]:
    """Return the spatial index model for the given CRS, or None if the CRS is not indexed."""
    # The stale pre-change line (self.spindexes.get(crs.epsg)) was left in as a
    # dead duplicate return; only the crs_to_epsg() form is kept.
    return self.spindexes.get(crs_to_epsg(crs))

def spatially_indexed_crses(self, refresh=False) -> Iterable[CRS]:
if refresh:
Expand Down

0 comments on commit a2f71d0

Please sign in to comment.