Ep13 Dataset API - Phase 1 (#1538)
* Reimplement datasets.get() via new datasets.get_unsafe() method.

* Deprecate get_derived().

* temporal_extent signatures - full implementations to come.

* Implement full temporal_extent method in postgis driver.

* Test and debug temporal_extent

* Update whats_new.rst.

* lintage.

* more lintage.

* Test coverage

* Fix bug in memory driver implementation.
SpacemanPaul committed Jan 29, 2024
1 parent bb42559 commit ad78e75
Showing 11 changed files with 271 additions and 74 deletions.
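
From the caller's side, the phase-1 changes amount to roughly the following (a hedged sketch; the product name and dataset id below are hypothetical placeholders, not values from this commit):

from datacube import Datacube

dc = Datacube()

# get() is now a thin wrapper over the new get_unsafe():
# get() returns None for an unknown id, get_unsafe() raises KeyError.
ds = dc.index.datasets.get("hypothetical-dataset-id")

# temporal_extent() accepts either a product (name or Product object) or an
# iterable of dataset ids; get_product_time_bounds() is deprecated in its favour.
t_min, t_max = dc.index.datasets.temporal_extent(product="hypothetical_product")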
36 changes: 36 additions & 0 deletions datacube/drivers/postgis/_api.py
@@ -13,12 +13,14 @@
Persistence API implementation for postgis.
"""

import datetime
import json
import logging
import uuid # noqa: F401
from sqlalchemy import cast
from sqlalchemy import delete, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.expression import Select
from sqlalchemy import select, text, and_, or_, func
from sqlalchemy.dialects.postgresql import INTERVAL
from sqlalchemy.exc import IntegrityError
@@ -1465,3 +1467,37 @@ def remove_lineage_relations(self,
qry = qry.where(DatasetLineage.source_dataset_ref.in_(ids))
results = self._connection.execute(qry)
return results.rowcount

def temporal_extent_by_prod(self, product_id: int) -> tuple[datetime.datetime, datetime.datetime]:
query = self.temporal_extent_full().where(Dataset.product_ref == product_id)
res = self._connection.execute(query)
return res.first()

def temporal_extent_by_ids(self, ids: Iterable[DSID]) -> tuple[datetime.datetime, datetime.datetime]:
query = self.temporal_extent_full().where(Dataset.id.in_(ids))
res = self._connection.execute(query)
return res.first()

def temporal_extent_full(self) -> Select:
# Hardcode eo3 standard time locations - do not use this approach in a legacy index driver.
time_min = DateDocField('aquisition_time_min',
'Min of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:start_datetime'],
['properties', 'datetime']
],
selection='least')
time_max = DateDocField('aquisition_time_max',
'Max of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:end_datetime'],
['properties', 'datetime']
],
selection='greatest')
return select(
func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)
)
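
For readers unfamiliar with DateDocField, the select above resolves to roughly the following hand-written SQLAlchemy expression (a sketch assuming metadata_doc is a JSONB column; the real field class also handles type casting and field descriptions):

from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import TIMESTAMP
from datacube.drivers.postgis._schema import Dataset

# Earliest/latest of the eo3 range properties and the plain datetime
# (PostgreSQL least()/greatest() ignore NULL arguments).
start = func.least(
    Dataset.metadata_doc[("properties", "dtr:start_datetime")].astext.cast(TIMESTAMP),
    Dataset.metadata_doc[("properties", "datetime")].astext.cast(TIMESTAMP),
)
end = func.greatest(
    Dataset.metadata_doc[("properties", "dtr:end_datetime")].astext.cast(TIMESTAMP),
    Dataset.metadata_doc[("properties", "datetime")].astext.cast(TIMESTAMP),
)
extent = select(func.min(start), func.max(end))  # callers add .where(...) as above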
62 changes: 58 additions & 4 deletions datacube/index/abstract.py
@@ -1062,14 +1062,38 @@ def __init__(self, index):
self.types = self.products # types is compatibility alias for products

@abstractmethod
def get_unsafe(self,
id_: DSID,
include_sources: bool = False,
include_deriveds: bool = False,
max_depth: int = 0
) -> Dataset:
"""
Get dataset by id (Raises KeyError if id_ does not exist)
- Index drivers supporting the legacy lineage API:
:param id_: id of the dataset to retrieve
:param include_sources: include the full provenance tree of the dataset.
- Index drivers supporting the external lineage API:
:param id_: id of the dataset to retrieve
:param include_sources: include the full provenance tree for the dataset.
:param include_deriveds: include the full derivative tree for the dataset.
:param max_depth: The maximum depth of the source and/or derived tree. Defaults to 0, meaning no limit.
:rtype: Dataset model (raises KeyError if not found)
"""

def get(self,
id_: DSID,
include_sources: bool = False,
include_deriveds: bool = False,
max_depth: int = 0
) -> Optional[Dataset]:
"""
Get dataset by id
Get dataset by id (returns None if id_ does not exist).
- Index drivers supporting the legacy lineage API:
@@ -1085,6 +1109,10 @@ def get(self,
:param max_depth: The maximum depth of the source and/or derived tree. Defaults to 0, meaning no limit.
:rtype: Dataset model (None if not found)
"""
try:
return self.get_unsafe(id_, include_sources, include_deriveds, max_depth)
except KeyError:
return None
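
A minimal usage sketch of the two entry points (the id is a hypothetical placeholder):

ds = index.datasets.get("hypothetical-id")          # Dataset, or None if not indexed
ds = index.datasets.get_unsafe("hypothetical-id")   # Dataset, or raises KeyError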

def _check_get_legacy(self,
include_deriveds: bool = False,
@@ -1112,6 +1140,10 @@ def bulk_get(self, ids: Iterable[DSID]) -> Iterable[Dataset]:
:return: Iterable of Dataset models
"""

@deprecat(
reason="The 'get_derived' static method is deprecated in favour of the new lineage API.",
version='1.9.0',
category=ODC2DeprecationWarning)
@abstractmethod
def get_derived(self, id_: DSID) -> Iterable[Dataset]:
"""
@@ -1664,15 +1696,37 @@ def search_eager(self, **query: QueryField) -> List[Dataset]:
return list(self.search(**query)) # type: ignore[arg-type] # mypy isn't being very smart here :(

@abstractmethod
def temporal_extent(self,
product: str | Product | None,
ids: Iterable[DSID] | None
) -> tuple[datetime.datetime, datetime.datetime]:
"""
Returns the minimum and maximum acquisition time of a product or an iterable of dataset ids.
Only one of ids or product may be passed - the other should be None. Raises ValueError if
both or neither of ids and product is passed. Raises KeyError if no datasets in the index
match the input argument.
:param product: Product or name of product
:param ids: Iterable of dataset ids.
:return: minimum and maximum acquisition times
"""

@deprecat(
reason="This method has been renamed 'temporal_extent'",
version="1.9.0",
category=ODC2DeprecationWarning
)
def get_product_time_bounds(self,
product: str
) -> Tuple[datetime.datetime, datetime.datetime]:
product: str | Product
) -> tuple[datetime.datetime, datetime.datetime]:
"""
Returns the minimum and maximum acquisition time of the product.
:param product: Name of product
:param product: Product or name of product
:return: minimum and maximum acquisition times
"""
return self.temporal_extent(product=product)
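
For callers still on the old name, a sketch of silencing the migration warning (the product name is a hypothetical placeholder):

import warnings
from datacube.migration import ODC2DeprecationWarning

with warnings.catch_warnings():
    warnings.simplefilter("ignore", ODC2DeprecationWarning)
    t_min, t_max = index.datasets.get_product_time_bounds("hypothetical_product")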

@abstractmethod
def search_returning_datasets_light(self,
45 changes: 26 additions & 19 deletions datacube/index/memory/_datasets.py
@@ -47,19 +47,16 @@ def __init__(self, index: AbstractIndex) -> None:
# Active Index By Product
self.by_product: MutableMapping[str, List[UUID]] = {}

def get(self, id_: DSID, include_sources: bool = False,
include_deriveds: bool = False, max_depth: int = 0) -> Optional[Dataset]:
def get_unsafe(self, id_: DSID, include_sources: bool = False,
include_deriveds: bool = False, max_depth: int = 0) -> Dataset:
self._check_get_legacy(include_deriveds, max_depth)
try:
ds = self.clone(self.by_id[dsid_to_uuid(id_)])
if include_sources:
ds.sources = {
classifier: cast(Dataset, self.get(dsid, include_sources=True))
for classifier, dsid in self.derived_from.get(ds.id, {}).items()
}
return ds
except KeyError:
return None
ds = self.clone(self.by_id[dsid_to_uuid(id_)]) # N.B. raises KeyError if id not in index.
if include_sources:
ds.sources = {
classifier: cast(Dataset, self.get(dsid, include_sources=True))
for classifier, dsid in self.derived_from.get(ds.id, {}).items()
}
return ds

def bulk_get(self, ids: Iterable[DSID]) -> Iterable[Dataset]:
return (ds for ds in (self.get(dsid) for dsid in ids) if ds is not None)
@@ -645,15 +642,25 @@ def make_summary(ds: Dataset) -> Mapping[str, Any]:
for ds in self.search(**query): # type: ignore[arg-type]
yield make_summary(ds)

def get_product_time_bounds(self, product: str) -> Tuple[datetime.datetime, datetime.datetime]:
def temporal_extent(
self,
product: str | Product | None = None,
ids: Iterable[DSID] | None = None
) -> tuple[datetime.datetime, datetime.datetime]:
if product is None and ids is None:
raise ValueError("Must supply product or ids")
elif product is not None and ids is not None:
raise ValueError("Cannot supply both product and ids")
elif product is not None:
if isinstance(product, str):
product = self._index.products.get_by_name_unsafe(product)
ids = self.by_product.get(product.name, [])

min_time: Optional[datetime.datetime] = None
max_time: Optional[datetime.datetime] = None
prod = self._index.products.get_by_name(product)
if prod is None:
raise ValueError(f"Product {product} not in index")
time_fld = prod.metadata_type.dataset_fields["time"]
for dsid in self.by_product.get(product, []):
ds = cast(Dataset, self.get(dsid))
for dsid in ids:
ds = self.get_unsafe(dsid)
time_fld = ds.product.metadata_type.dataset_fields["time"]
dsmin, dsmax = time_fld.extract(ds.metadata_doc) # type: ignore[attr-defined]
if dsmax is None and dsmin is None:
continue
21 changes: 17 additions & 4 deletions datacube/index/null/_datasets.py
@@ -3,6 +3,8 @@
# Copyright (c) 2015-2024 ODC Contributors
# SPDX-License-Identifier: Apache-2.0

import datetime

from datacube.index.abstract import AbstractDatasetResource, DSID
from datacube.model import Dataset, Product
from typing import Iterable, Optional
@@ -12,8 +14,8 @@ class DatasetResource(AbstractDatasetResource):
def __init__(self, index):
super().__init__(index)

def get(self, id_: DSID, include_sources: bool = False, include_deriveds: bool = False, max_depth: int = 0):
return None
def get_unsafe(self, id_: DSID, include_sources: bool = False, include_deriveds: bool = False, max_depth: int = 0):
raise KeyError(id_)

def bulk_get(self, ids):
return []
@@ -104,8 +106,19 @@ def count_product_through_time(self, period, **query):
def search_summaries(self, **query):
return []

def get_product_time_bounds(self, product: str):
raise NotImplementedError()
def temporal_extent(
self,
product: str | Product | None = None,
ids: Iterable[DSID] | None = None
) -> tuple[datetime.datetime, datetime.datetime]:
if product is None and ids is None:
raise ValueError("Must specify product or ids")
elif ids is not None and product is not None:
raise ValueError("Cannot specify both product and ids")
elif ids is not None:
raise KeyError(str(ids))
else:
raise KeyError(str(product))

# pylint: disable=redefined-outer-name
def search_returning_datasets_light(self, field_names: tuple, custom_offsets=None, limit=None, **query):
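
Combined with the base-class wrapper in abstract.py, these overrides mean null-index lookups fail softly through get() and loudly through get_unsafe(); a sketch (null_datasets is a hypothetical instance of this resource):

ds = null_datasets.get("00000000-0000-0000-0000-000000000000")  # KeyError from get_unsafe() is caught by get()
assert ds is None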
66 changes: 29 additions & 37 deletions datacube/index/postgis/_datasets.py
@@ -5,6 +5,7 @@
"""
API for dataset indexing, access and search.
"""
import datetime
import json
import logging
import warnings
@@ -13,14 +13,14 @@
from typing import Iterable, List, Mapping, Union, Optional, Any
from uuid import UUID

from sqlalchemy import select, func
from deprecat import deprecat

from datacube.drivers.postgis._fields import SimpleDocField, DateDocField
from datacube.drivers.postgis._fields import SimpleDocField
from datacube.drivers.postgis._schema import Dataset as SQLDataset, search_field_map
from datacube.drivers.postgis._api import non_native_fields, extract_dataset_fields
from datacube.utils.uris import split_uri
from datacube.drivers.postgis._spatial import generate_dataset_spatial_values, extract_geometry_from_eo3_projection

from datacube.migration import ODC2DeprecationWarning
from datacube.index.abstract import AbstractDatasetResource, DatasetSpatialMixin, DSID, BatchStatus, DatasetTuple
from datacube.index.postgis._transaction import IndexResourceAddIn
from datacube.model import Dataset, Product, Range, LineageTree
@@ -51,10 +52,10 @@ def __init__(self, db, index):
self._db = db
super().__init__(index)

def get(self, id_: DSID,
include_sources: bool = False, include_deriveds: bool = False, max_depth: int = 0) -> Optional[Dataset]:
def get_unsafe(self, id_: DSID,
include_sources: bool = False, include_deriveds: bool = False, max_depth: int = 0) -> Dataset:
"""
Get dataset by id
Get dataset by id (raise KeyError if not found)
:param id_: id of the dataset to retrieve
:param include_sources: include the full provenance tree for the dataset.
@@ -74,7 +75,7 @@ def get(self, id_: DSID,
with self._db_connection() as connection:
dataset = connection.get_dataset(id_)
if not dataset:
return None
raise KeyError(id_)
return self._make(dataset, full_info=True, source_tree=source_tree, derived_tree=derived_tree)

def bulk_get(self, ids):
@@ -87,6 +88,10 @@ def to_uuid(x):
rows = connection.get_datasets(ids)
return [self._make(r, full_info=True) for r in rows]

@deprecat(
reason="The 'get_derived' static method is deprecated in favour of the new lineage API.",
version='1.9.0',
category=ODC2DeprecationWarning)
def get_derived(self, id_):
"""
Get all derived datasets
@@ -742,39 +747,26 @@ def search_summaries(self, **query):
_LOG.warning("search results: %s (%s)", output["id"], output["product"])
yield output

def get_product_time_bounds(self, product: str):
def temporal_extent(
self,
product: str | Product | None = None,
ids: Iterable[DSID] | None = None
) -> tuple[datetime.datetime, datetime.datetime]:
"""
Returns the minimum and maximum acquisition time of the product, or of the supplied dataset ids.
"""

# Get the offsets from dataset doc
product = self.products.get_by_name(product)
dataset_section = product.metadata_type.definition['dataset']
min_offset = dataset_section['search_fields']['time']['min_offset']
max_offset = dataset_section['search_fields']['time']['max_offset']

time_min = DateDocField('aquisition_time_min',
'Min of time when dataset was acquired',
SQLDataset.metadata_doc,
False, # is it indexed
offset=min_offset,
selection='least')

time_max = DateDocField('aquisition_time_max',
'Max of time when dataset was acquired',
SQLDataset.metadata_doc,
False, # is it indexed
offset=max_offset,
selection='greatest')

with self._db_connection() as connection:
result = connection.execute(
select(
[func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)]
).where(
SQLDataset.product_ref == product.id
)
).first()
if product is None and ids is None:
raise ValueError("Must supply product or ids")
elif product is not None and ids is not None:
raise ValueError("Cannot supply both product and ids")
elif product is not None:
if isinstance(product, str):
product = self._index.products.get_by_name_unsafe(product)
with self._db_connection() as connection:
result = connection.temporal_extent_by_prod(product.id)
else:
with self._db_connection() as connection:
result = connection.temporal_extent_by_ids(ids)

return result

