Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-37938: additional spatial constraint query fixes and testing #787

Merged
Merged 5 commits on Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-37938.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix additional bugs in spatial query constraints introduced in DM-31725.
91 changes: 42 additions & 49 deletions python/lsst/daf/butler/registry/queries/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

__all__ = ("QueryBuilder",)

from typing import Any, cast
import itertools
from typing import Any

from lsst.daf.relation import ColumnExpression, ColumnTag, Diagnostics, Predicate, Relation
from lsst.daf.relation import ColumnExpression, ColumnTag, Diagnostics, Relation

from ...core import (
ColumnCategorization,
Expand Down Expand Up @@ -127,9 +128,7 @@ def joinDataset(
collections,
governor_constraints=self._governor_constraints,
rejections=rejections,
allow_calibration_collections=(
not findFirst and not (self.summary.temporal or self.summary.dimensions.temporal)
),
allow_calibration_collections=(not findFirst and not self.summary.dimensions.temporal),
)
columns_requested = {"dataset_id", "run", "ingest_date"} if isResult else frozenset()
if not collection_records:
Expand Down Expand Up @@ -168,41 +167,39 @@ def _addWhereClause(self, categorized_columns: ColumnCategorization) -> None:
`ColumnTag` type.
"""
# Append WHERE clause terms from predicates.
predicate: Predicate = Predicate.literal(True)
if self.summary.where.expression_predicate is not None:
predicate = predicate.logical_and(self.summary.where.expression_predicate)
self.relation = self.relation.with_rows_satisfying(
self.summary.where.expression_predicate,
preferred_engine=self._context.preferred_engine,
require_preferred_engine=True,
)
if self.summary.where.data_id:
known_dimensions = self.summary.where.data_id.graph.intersection(self.summary.dimensions)
known_data_id = self.summary.where.data_id.subset(known_dimensions)
predicate = predicate.logical_and(self._context.make_data_coordinate_predicate(known_data_id))
if self.summary.where.region is not None:
self.relation = self.relation.with_rows_satisfying(
self._context.make_data_coordinate_predicate(known_data_id),
preferred_engine=self._context.preferred_engine,
require_preferred_engine=True,
)
if self.summary.region is not None:
for skypix_dimension in categorized_columns.filter_skypix(self._backend.universe):
if skypix_dimension not in self.summary.where.data_id.graph:
predicate = predicate.logical_and(
self._context.make_spatial_region_skypix_predicate(
skypix_dimension,
self.summary.where.region,
)
)
self.relation = self.relation.with_rows_satisfying(
self._context.make_spatial_region_skypix_predicate(
skypix_dimension,
self.summary.region,
),
preferred_engine=self._context.preferred_engine,
require_preferred_engine=True,
)
for element in categorized_columns.filter_spatial_region_dimension_elements():
if element not in self.summary.where.data_id.graph.names:
predicate = predicate.logical_and(
self._context.make_spatial_region_overlap_predicate(
ColumnExpression.reference(DimensionRecordColumnTag(element, "region")),
ColumnExpression.literal(self.summary.where.region),
)
)
if self.summary.where.timespan is not None:
for element in categorized_columns.filter_timespan_dimension_elements():
if element not in self.summary.where.data_id.graph.names:
predicate = predicate.logical_and(
self._context.make_timespan_overlap_predicate(
DimensionRecordColumnTag(element, "timespan"), self.summary.where.timespan
)
)
self.relation = self.relation.with_rows_satisfying(
predicate, preferred_engine=self._context.preferred_engine, require_preferred_engine=True
)
self.relation = self.relation.with_rows_satisfying(
self._context.make_spatial_region_overlap_predicate(
ColumnExpression.reference(DimensionRecordColumnTag(element, "region")),
ColumnExpression.literal(self.summary.region),
),
preferred_engine=self._context.iteration_engine,
transfer=True,
)

def finish(self, joinMissing: bool = True) -> Query:
"""Finish query constructing, returning a new `Query` instance.
Expand All @@ -223,28 +220,24 @@ def finish(self, joinMissing: bool = True) -> Query:
A `Query` object that can be executed and used to interpret result
rows.
"""
columns_required: set[ColumnTag] = set()
if self.summary.where.expression_predicate is not None:
columns_required.update(self.summary.where.expression_predicate.columns_required)
if self.summary.order_by is not None:
columns_required.update(self.summary.order_by.columns_required)
columns_required.update(DimensionKeyColumnTag.generate(self.summary.requested.names))
if self.summary.universe.commonSkyPix in self.summary.spatial:
columns_required.add(DimensionKeyColumnTag(self.summary.universe.commonSkyPix.name))
if joinMissing:
spatial_joins = []
for family1, family2 in itertools.combinations(self.summary.dimensions.spatial, 2):
spatial_joins.append(
(
family1.choose(self.summary.dimensions.elements).name,
family2.choose(self.summary.dimensions.elements).name,
)
)
self.relation = self._backend.make_dimension_relation(
self.summary.dimensions,
columns=columns_required,
columns=self.summary.columns_required,
context=self._context,
spatial_joins=(
[cast(tuple[str, str], tuple(self.summary.spatial.names))]
if len(self.summary.spatial) == 2
else []
),
spatial_joins=spatial_joins,
initial_relation=self.relation,
governor_constraints=self._governor_constraints,
)
categorized_columns = ColumnCategorization.from_iterable(columns_required)
categorized_columns = ColumnCategorization.from_iterable(self.relation.columns)
self._addWhereClause(categorized_columns)
query = Query(
self.summary.dimensions,
Expand Down
141 changes: 52 additions & 89 deletions python/lsst/daf/butler/registry/queries/_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import astropy.time
from lsst.daf.relation import ColumnExpression, ColumnTag, Predicate, SortTerm
from lsst.sphgeom import Region
from lsst.sphgeom import IntersectionRegion, Region
from lsst.utils.classes import cached_getter, immutable

from ...core import (
Expand All @@ -42,7 +42,6 @@
NamedValueAbstractSet,
NamedValueSet,
SkyPixDimension,
Timespan,
)

# We're not trying to add typing to the lex/yacc parser code, so MyPy
Expand All @@ -69,7 +68,6 @@ def combine(
bind: Mapping[str, Any] | None = None,
data_id: DataCoordinate | None = None,
region: Region | None = None,
timespan: Timespan | None = None,
defaults: DataCoordinate | None = None,
dataset_type_name: str | None = None,
allow_orphans: bool = False,
Expand All @@ -90,13 +88,7 @@ def combine(
A data ID identifying dimensions known in advance. If not
provided, will be set to an empty data ID.
region : `lsst.sphgeom.Region`, optional
A spatial constraint that all rows must overlap. If `None` and
``data_id`` is an expanded data ID, ``data_id.region`` will be used
to construct one.
timespan : `Timespan`, optional
A temporal constraint that all rows must overlap. If `None` and
``data_id`` is an expanded data ID, ``data_id.timespan`` will be
used to construct one.
A spatial constraint that all rows must overlap.
defaults : `DataCoordinate`, optional
A data ID containing default for governor dimensions.
dataset_type_name : `str` or `None`, optional
Expand All @@ -113,11 +105,6 @@ def combine(
where : `QueryWhereClause`
An object representing the WHERE clause for a query.
"""
if data_id is not None and data_id.hasRecords():
if region is None and data_id.region is not None:
region = data_id.region
if timespan is None and data_id.timespan is not None:
timespan = data_id.timespan
if data_id is None:
data_id = DataCoordinate.makeEmpty(dimensions.universe)
if defaults is None:
Expand All @@ -135,7 +122,6 @@ def combine(
expression_predicate,
data_id,
region=region,
timespan=timespan,
governor_constraints=governor_constraints,
)

Expand All @@ -154,11 +140,6 @@ def combine(
(`lsst.sphgeom.Region` or `None`).
"""

timespan: Timespan | None
"""A temporal constraint that all result rows must overlap
(`Timespan` or `None`).
"""

governor_constraints: Mapping[str, Set[str]]
"""Restrictions on the values governor dimensions can take in this query,
imposed by the string expression and/or data ID
Expand Down Expand Up @@ -349,12 +330,9 @@ class QuerySummary:
expression : `str`, optional
A user-provided string WHERE expression.
region : `lsst.sphgeom.Region`, optional
If `None` and ``data_id`` is an expanded data ID, ``data_id.region``
will be used to construct one.
A spatial constraint that all rows must overlap.
timespan : `Timespan`, optional
A temporal constraint that all rows must overlap. If `None` and
``data_id`` is an expanded data ID, ``data_id.timespan`` will be used
to construct one.
A temporal constraint that all rows must overlap.
bind : `Mapping` [ `str`, `object` ], optional
Mapping containing literal values that should be injected into the
query expression, keyed by the identifiers they replace.
Expand Down Expand Up @@ -384,7 +362,6 @@ def __init__(
data_id: DataCoordinate | None = None,
expression: str = "",
region: Region | None = None,
timespan: Timespan | None = None,
bind: Mapping[str, Any] | None = None,
defaults: DataCoordinate | None = None,
datasets: Iterable[DatasetType] = (),
Expand All @@ -404,14 +381,13 @@ def __init__(
bind=bind,
data_id=data_id,
region=region,
timespan=timespan,
defaults=defaults,
dataset_type_name=dataset_type_name,
allow_orphans=not check,
)
self.order_by = None if order_by is None else OrderByClause.parse_general(order_by, requested)
self.limit = limit
self.columns_required, self.dimensions = self._compute_columns_required()
self.columns_required, self.dimensions, self.region = self._compute_columns_required()

requested: DimensionGraph
"""Dimensions whose primary keys should be included in the result rows of
Expand Down Expand Up @@ -442,6 +418,15 @@ def __init__(
"""All dimensions in the query in any form (`DimensionGraph`).
"""

region: Region | None
    """Region that bounds all query results (`lsst.sphgeom.Region` or
    `None`).

While `QueryWhereClause.region` and the ``region`` constructor argument
represent an external region given directly by the caller, this represents
the region actually used directly as a constraint on the query results,
which can also come from the data ID passed by the caller.
"""

columns_required: Set[ColumnTag]
"""All columns that must be included directly in the query.

Expand All @@ -454,57 +439,9 @@ def universe(self) -> DimensionUniverse:
"""All known dimensions (`DimensionUniverse`)."""
return self.requested.universe

@property
@cached_getter
def spatial(self) -> NamedValueAbstractSet[DimensionElement]:
"""Dimension elements whose regions and skypix IDs should be included
in the query (`NamedValueAbstractSet` of `DimensionElement`).
"""
# An element may participate spatially in the query if:
# - it's the most precise spatial element for its system in the
# requested dimensions (i.e. in `self.requested.spatial`);
# - it isn't also given at query construction time.
result: NamedValueSet[DimensionElement] = NamedValueSet()
for family in self.dimensions.spatial:
element = family.choose(self.dimensions.elements)
assert isinstance(element, DimensionElement)
if element not in self.where.data_id.graph.elements:
result.add(element)
if len(result) == 1:
# There's no spatial join, but there might be a WHERE filter based
# on a given region.
if self.where.data_id.graph.spatial:
# We can only perform those filters against SkyPix dimensions,
# so if what we have isn't one, add the common SkyPix dimension
# to the query; the element we have will be joined to that.
(element,) = result
if not isinstance(element, SkyPixDimension):
result.add(self.universe.commonSkyPix)
else:
# There is no spatial join or filter in this query. Even
# if this element might be associated with spatial
# information, we don't need it for this query.
return NamedValueSet().freeze()
return result.freeze()

@property
@cached_getter
def temporal(self) -> NamedValueAbstractSet[DimensionElement]:
"""Dimension elements whose timespans should be included in the
query (`NamedValueSet` of `DimensionElement`).
"""
if len(self.dimensions.temporal) > 1:
# We don't actually have multiple temporal families in our current
# dimension configuration, so this limitation should be harmless.
raise NotImplementedError("Queries that should involve temporal joins are not yet supported.")
result = NamedValueSet[DimensionElement]()
if self.where.expression_predicate is not None:
for tag in DimensionRecordColumnTag.filter_from(self.where.expression_predicate.columns_required):
if tag.column == "timespan":
result.add(self.requested.universe[tag.element])
return result.freeze()

def _compute_columns_required(self) -> tuple[set[ColumnTag], DimensionGraph]:
def _compute_columns_required(
self,
) -> tuple[set[ColumnTag], DimensionGraph, Region | None]:
"""Compute the columns that must be provided by the relations joined
into this query in order to obtain the right *set* of result rows in
the right order.
Expand All @@ -513,24 +450,50 @@ def _compute_columns_required(self) -> tuple[set[ColumnTag], DimensionGraph]:
result rows, and hence could be provided by postprocessors.
"""
tags: set[ColumnTag] = set(DimensionKeyColumnTag.generate(self.requested.names))
tags.update(
DimensionKeyColumnTag.generate(
dimension.name
for dimension in self.where.data_id.graph
if dimension == self.requested.universe.commonSkyPix
or not isinstance(dimension, SkyPixDimension)
)
)
for dataset_type in self.datasets:
tags.update(DimensionKeyColumnTag.generate(dataset_type.dimensions.names))
if self.where.expression_predicate is not None:
tags.update(self.where.expression_predicate.columns_required)
if self.order_by is not None:
tags.update(self.order_by.columns_required)
region = self.where.region
for dimension in self.where.data_id.graph:
dimension_tag = DimensionKeyColumnTag(dimension.name)
if dimension_tag in tags:
continue
if dimension == self.universe.commonSkyPix or not isinstance(dimension, SkyPixDimension):
# If a dimension in the data ID is available from dimension
# tables or dimension spatial-join tables in the database,
# include it in the set of dimensions whose tables should be
# joined. This makes these data ID constraints work just like
# simple 'where' constraints, which is good.
tags.add(dimension_tag)
else:
# This is a SkyPixDimension other than the common one. If it's
# not already present in the query (e.g. from a dataset join),
# this is a pure spatial constraint, which we can only apply by
# modifying the 'region' for the query. That will also require
# that we join in the common skypix dimension.
pixel = dimension.pixelization.pixel(self.where.data_id[dimension])
if region is None:
region = pixel
else:
region = IntersectionRegion(region, pixel)
# Make sure the dimension keys are expanded self-consistently in what
# we return by passing them through DimensionGraph.
dimensions = DimensionGraph(
self.universe, names={tag.dimension for tag in DimensionKeyColumnTag.filter_from(tags)}
)
# If we have a region constraint, ensure region columns and the common
# skypix dimension are included.
missing_common_skypix = False
if region is not None:
for family in dimensions.spatial:
element = family.choose(dimensions.elements)
tags.add(DimensionRecordColumnTag(element.name, "region"))
if not isinstance(element, SkyPixDimension) and self.universe.commonSkyPix not in dimensions:
missing_common_skypix = True
if missing_common_skypix:
dimensions = dimensions.union(self.universe.commonSkyPix.graph)
tags.update(DimensionKeyColumnTag.generate(dimensions.names))
return (tags, dimensions)
return (tags, dimensions, region)
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registry/tests/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3212,7 +3212,7 @@ def test_skypix_constraint_queries(self) -> None:
# Gather all skypix IDs that definitely overlap at least one of these
# patches.
relevant_skypix_ids = lsst.sphgeom.RangeSet()
for patch_key, patch_region in patch_regions.items():
for patch_region in patch_regions.values():
relevant_skypix_ids |= skypix_dimension.pixelization.interior(patch_region)
# Look for a "nontrivial" skypix_id that overlaps at least one patch
# and does not overlap at least one other patch.
Expand Down