Skip to content

Commit

Permalink
Merge pull request #787 from lsst/tickets/DM-37938
Browse files Browse the repository at this point in the history
DM-37938: additional spatial constraint query fixes and testing
  • Loading branch information
TallJimbo committed Feb 11, 2023
2 parents 460accc + 80c95a8 commit c645d68
Show file tree
Hide file tree
Showing 5 changed files with 443 additions and 139 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-37938.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix additional bugs in spatial query constraints introduced in DM-31725.
91 changes: 42 additions & 49 deletions python/lsst/daf/butler/registry/queries/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

__all__ = ("QueryBuilder",)

from typing import Any, cast
import itertools
from typing import Any

from lsst.daf.relation import ColumnExpression, ColumnTag, Diagnostics, Predicate, Relation
from lsst.daf.relation import ColumnExpression, ColumnTag, Diagnostics, Relation

from ...core import (
ColumnCategorization,
Expand Down Expand Up @@ -127,9 +128,7 @@ def joinDataset(
collections,
governor_constraints=self._governor_constraints,
rejections=rejections,
allow_calibration_collections=(
not findFirst and not (self.summary.temporal or self.summary.dimensions.temporal)
),
allow_calibration_collections=(not findFirst and not self.summary.dimensions.temporal),
)
columns_requested = {"dataset_id", "run", "ingest_date"} if isResult else frozenset()
if not collection_records:
Expand Down Expand Up @@ -168,41 +167,39 @@ def _addWhereClause(self, categorized_columns: ColumnCategorization) -> None:
`ColumnTag` type.
"""
# Append WHERE clause terms from predicates.
predicate: Predicate = Predicate.literal(True)
if self.summary.where.expression_predicate is not None:
predicate = predicate.logical_and(self.summary.where.expression_predicate)
self.relation = self.relation.with_rows_satisfying(
self.summary.where.expression_predicate,
preferred_engine=self._context.preferred_engine,
require_preferred_engine=True,
)
if self.summary.where.data_id:
known_dimensions = self.summary.where.data_id.graph.intersection(self.summary.dimensions)
known_data_id = self.summary.where.data_id.subset(known_dimensions)
predicate = predicate.logical_and(self._context.make_data_coordinate_predicate(known_data_id))
if self.summary.where.region is not None:
self.relation = self.relation.with_rows_satisfying(
self._context.make_data_coordinate_predicate(known_data_id),
preferred_engine=self._context.preferred_engine,
require_preferred_engine=True,
)
if self.summary.region is not None:
for skypix_dimension in categorized_columns.filter_skypix(self._backend.universe):
if skypix_dimension not in self.summary.where.data_id.graph:
predicate = predicate.logical_and(
self._context.make_spatial_region_skypix_predicate(
skypix_dimension,
self.summary.where.region,
)
)
self.relation = self.relation.with_rows_satisfying(
self._context.make_spatial_region_skypix_predicate(
skypix_dimension,
self.summary.region,
),
preferred_engine=self._context.preferred_engine,
require_preferred_engine=True,
)
for element in categorized_columns.filter_spatial_region_dimension_elements():
if element not in self.summary.where.data_id.graph.names:
predicate = predicate.logical_and(
self._context.make_spatial_region_overlap_predicate(
ColumnExpression.reference(DimensionRecordColumnTag(element, "region")),
ColumnExpression.literal(self.summary.where.region),
)
)
if self.summary.where.timespan is not None:
for element in categorized_columns.filter_timespan_dimension_elements():
if element not in self.summary.where.data_id.graph.names:
predicate = predicate.logical_and(
self._context.make_timespan_overlap_predicate(
DimensionRecordColumnTag(element, "timespan"), self.summary.where.timespan
)
)
self.relation = self.relation.with_rows_satisfying(
predicate, preferred_engine=self._context.preferred_engine, require_preferred_engine=True
)
self.relation = self.relation.with_rows_satisfying(
self._context.make_spatial_region_overlap_predicate(
ColumnExpression.reference(DimensionRecordColumnTag(element, "region")),
ColumnExpression.literal(self.summary.region),
),
preferred_engine=self._context.iteration_engine,
transfer=True,
)

def finish(self, joinMissing: bool = True) -> Query:
"""Finish query constructing, returning a new `Query` instance.
Expand All @@ -223,28 +220,24 @@ def finish(self, joinMissing: bool = True) -> Query:
A `Query` object that can be executed and used to interpret result
rows.
"""
columns_required: set[ColumnTag] = set()
if self.summary.where.expression_predicate is not None:
columns_required.update(self.summary.where.expression_predicate.columns_required)
if self.summary.order_by is not None:
columns_required.update(self.summary.order_by.columns_required)
columns_required.update(DimensionKeyColumnTag.generate(self.summary.requested.names))
if self.summary.universe.commonSkyPix in self.summary.spatial:
columns_required.add(DimensionKeyColumnTag(self.summary.universe.commonSkyPix.name))
if joinMissing:
spatial_joins = []
for family1, family2 in itertools.combinations(self.summary.dimensions.spatial, 2):
spatial_joins.append(
(
family1.choose(self.summary.dimensions.elements).name,
family2.choose(self.summary.dimensions.elements).name,
)
)
self.relation = self._backend.make_dimension_relation(
self.summary.dimensions,
columns=columns_required,
columns=self.summary.columns_required,
context=self._context,
spatial_joins=(
[cast(tuple[str, str], tuple(self.summary.spatial.names))]
if len(self.summary.spatial) == 2
else []
),
spatial_joins=spatial_joins,
initial_relation=self.relation,
governor_constraints=self._governor_constraints,
)
categorized_columns = ColumnCategorization.from_iterable(columns_required)
categorized_columns = ColumnCategorization.from_iterable(self.relation.columns)
self._addWhereClause(categorized_columns)
query = Query(
self.summary.dimensions,
Expand Down
141 changes: 52 additions & 89 deletions python/lsst/daf/butler/registry/queries/_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import astropy.time
from lsst.daf.relation import ColumnExpression, ColumnTag, Predicate, SortTerm
from lsst.sphgeom import Region
from lsst.sphgeom import IntersectionRegion, Region
from lsst.utils.classes import cached_getter, immutable

from ...core import (
Expand All @@ -42,7 +42,6 @@
NamedValueAbstractSet,
NamedValueSet,
SkyPixDimension,
Timespan,
)

# We're not trying to add typing to the lex/yacc parser code, so MyPy
Expand All @@ -69,7 +68,6 @@ def combine(
bind: Mapping[str, Any] | None = None,
data_id: DataCoordinate | None = None,
region: Region | None = None,
timespan: Timespan | None = None,
defaults: DataCoordinate | None = None,
dataset_type_name: str | None = None,
allow_orphans: bool = False,
Expand All @@ -90,13 +88,7 @@ def combine(
A data ID identifying dimensions known in advance. If not
provided, will be set to an empty data ID.
region : `lsst.sphgeom.Region`, optional
A spatial constraint that all rows must overlap. If `None` and
``data_id`` is an expanded data ID, ``data_id.region`` will be used
to construct one.
timespan : `Timespan`, optional
A temporal constraint that all rows must overlap. If `None` and
``data_id`` is an expanded data ID, ``data_id.timespan`` will be
used to construct one.
A spatial constraint that all rows must overlap.
defaults : `DataCoordinate`, optional
A data ID containing defaults for governor dimensions.
dataset_type_name : `str` or `None`, optional
Expand All @@ -113,11 +105,6 @@ def combine(
where : `QueryWhereClause`
An object representing the WHERE clause for a query.
"""
if data_id is not None and data_id.hasRecords():
if region is None and data_id.region is not None:
region = data_id.region
if timespan is None and data_id.timespan is not None:
timespan = data_id.timespan
if data_id is None:
data_id = DataCoordinate.makeEmpty(dimensions.universe)
if defaults is None:
Expand All @@ -135,7 +122,6 @@ def combine(
expression_predicate,
data_id,
region=region,
timespan=timespan,
governor_constraints=governor_constraints,
)

Expand All @@ -154,11 +140,6 @@ def combine(
(`lsst.sphgeom.Region` or `None`).
"""

timespan: Timespan | None
"""A temporal constraint that all result rows must overlap
(`Timespan` or `None`).
"""

governor_constraints: Mapping[str, Set[str]]
"""Restrictions on the values governor dimensions can take in this query,
imposed by the string expression and/or data ID
Expand Down Expand Up @@ -349,12 +330,9 @@ class QuerySummary:
expression : `str`, optional
A user-provided string WHERE expression.
region : `lsst.sphgeom.Region`, optional
If `None` and ``data_id`` is an expanded data ID, ``data_id.region``
will be used to construct one.
A spatial constraint that all rows must overlap.
timespan : `Timespan`, optional
A temporal constraint that all rows must overlap. If `None` and
``data_id`` is an expanded data ID, ``data_id.timespan`` will be used
to construct one.
A temporal constraint that all rows must overlap.
bind : `Mapping` [ `str`, `object` ], optional
Mapping containing literal values that should be injected into the
query expression, keyed by the identifiers they replace.
Expand Down Expand Up @@ -384,7 +362,6 @@ def __init__(
data_id: DataCoordinate | None = None,
expression: str = "",
region: Region | None = None,
timespan: Timespan | None = None,
bind: Mapping[str, Any] | None = None,
defaults: DataCoordinate | None = None,
datasets: Iterable[DatasetType] = (),
Expand All @@ -404,14 +381,13 @@ def __init__(
bind=bind,
data_id=data_id,
region=region,
timespan=timespan,
defaults=defaults,
dataset_type_name=dataset_type_name,
allow_orphans=not check,
)
self.order_by = None if order_by is None else OrderByClause.parse_general(order_by, requested)
self.limit = limit
self.columns_required, self.dimensions = self._compute_columns_required()
self.columns_required, self.dimensions, self.region = self._compute_columns_required()

requested: DimensionGraph
"""Dimensions whose primary keys should be included in the result rows of
Expand Down Expand Up @@ -442,6 +418,15 @@ def __init__(
"""All dimensions in the query in any form (`DimensionGraph`).
"""

region: Region | None
"""Region that bounds all query results (`lsst.sphgeom.Region`).
While `QueryWhereClause.region` and the ``region`` constructor argument
represent an external region given directly by the caller, this represents
the region actually used directly as a constraint on the query results,
which can also come from the data ID passed by the caller.
"""

columns_required: Set[ColumnTag]
"""All columns that must be included directly in the query.
Expand All @@ -454,57 +439,9 @@ def universe(self) -> DimensionUniverse:
"""All known dimensions (`DimensionUniverse`)."""
return self.requested.universe

@property
@cached_getter
def spatial(self) -> NamedValueAbstractSet[DimensionElement]:
"""Dimension elements whose regions and skypix IDs should be included
in the query (`NamedValueAbstractSet` of `DimensionElement`).
"""
# An element may participate spatially in the query if:
# - it's the most precise spatial element for its system in the
# requested dimensions (i.e. in `self.requested.spatial`);
# - it isn't also given at query construction time.
result: NamedValueSet[DimensionElement] = NamedValueSet()
for family in self.dimensions.spatial:
element = family.choose(self.dimensions.elements)
assert isinstance(element, DimensionElement)
if element not in self.where.data_id.graph.elements:
result.add(element)
if len(result) == 1:
# There's no spatial join, but there might be a WHERE filter based
# on a given region.
if self.where.data_id.graph.spatial:
# We can only perform those filters against SkyPix dimensions,
# so if what we have isn't one, add the common SkyPix dimension
# to the query; the element we have will be joined to that.
(element,) = result
if not isinstance(element, SkyPixDimension):
result.add(self.universe.commonSkyPix)
else:
# There is no spatial join or filter in this query. Even
# if this element might be associated with spatial
# information, we don't need it for this query.
return NamedValueSet().freeze()
return result.freeze()

@property
@cached_getter
def temporal(self) -> NamedValueAbstractSet[DimensionElement]:
"""Dimension elements whose timespans should be included in the
query (`NamedValueSet` of `DimensionElement`).
"""
if len(self.dimensions.temporal) > 1:
# We don't actually have multiple temporal families in our current
# dimension configuration, so this limitation should be harmless.
raise NotImplementedError("Queries that should involve temporal joins are not yet supported.")
result = NamedValueSet[DimensionElement]()
if self.where.expression_predicate is not None:
for tag in DimensionRecordColumnTag.filter_from(self.where.expression_predicate.columns_required):
if tag.column == "timespan":
result.add(self.requested.universe[tag.element])
return result.freeze()

def _compute_columns_required(self) -> tuple[set[ColumnTag], DimensionGraph]:
def _compute_columns_required(
self,
) -> tuple[set[ColumnTag], DimensionGraph, Region | None]:
"""Compute the columns that must be provided by the relations joined
into this query in order to obtain the right *set* of result rows in
the right order.
Expand All @@ -513,24 +450,50 @@ def _compute_columns_required(self) -> tuple[set[ColumnTag], DimensionGraph]:
result rows, and hence could be provided by postprocessors.
"""
tags: set[ColumnTag] = set(DimensionKeyColumnTag.generate(self.requested.names))
tags.update(
DimensionKeyColumnTag.generate(
dimension.name
for dimension in self.where.data_id.graph
if dimension == self.requested.universe.commonSkyPix
or not isinstance(dimension, SkyPixDimension)
)
)
for dataset_type in self.datasets:
tags.update(DimensionKeyColumnTag.generate(dataset_type.dimensions.names))
if self.where.expression_predicate is not None:
tags.update(self.where.expression_predicate.columns_required)
if self.order_by is not None:
tags.update(self.order_by.columns_required)
region = self.where.region
for dimension in self.where.data_id.graph:
dimension_tag = DimensionKeyColumnTag(dimension.name)
if dimension_tag in tags:
continue
if dimension == self.universe.commonSkyPix or not isinstance(dimension, SkyPixDimension):
# If a dimension in the data ID is available from dimension
# tables or dimension spatial-join tables in the database,
# include it in the set of dimensions whose tables should be
# joined. This makes these data ID constraints work just like
# simple 'where' constraints, which is good.
tags.add(dimension_tag)
else:
# This is a SkyPixDimension other than the common one. If it's
# not already present in the query (e.g. from a dataset join),
# this is a pure spatial constraint, which we can only apply by
# modifying the 'region' for the query. That will also require
# that we join in the common skypix dimension.
pixel = dimension.pixelization.pixel(self.where.data_id[dimension])
if region is None:
region = pixel
else:
region = IntersectionRegion(region, pixel)
# Make sure the dimension keys are expanded self-consistently in what
# we return by passing them through DimensionGraph.
dimensions = DimensionGraph(
self.universe, names={tag.dimension for tag in DimensionKeyColumnTag.filter_from(tags)}
)
# If we have a region constraint, ensure region columns and the common
# skypix dimension are included.
missing_common_skypix = False
if region is not None:
for family in dimensions.spatial:
element = family.choose(dimensions.elements)
tags.add(DimensionRecordColumnTag(element.name, "region"))
if not isinstance(element, SkyPixDimension) and self.universe.commonSkyPix not in dimensions:
missing_common_skypix = True
if missing_common_skypix:
dimensions = dimensions.union(self.universe.commonSkyPix.graph)
tags.update(DimensionKeyColumnTag.generate(dimensions.names))
return (tags, dimensions)
return (tags, dimensions, region)
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registry/tests/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3212,7 +3212,7 @@ def test_skypix_constraint_queries(self) -> None:
# Gather all skypix IDs that definitely overlap at least one of these
# patches.
relevant_skypix_ids = lsst.sphgeom.RangeSet()
for patch_key, patch_region in patch_regions.items():
for patch_region in patch_regions.values():
relevant_skypix_ids |= skypix_dimension.pixelization.interior(patch_region)
# Look for a "nontrivial" skypix_id that overlaps at least one patch
# and does not overlap at least one other patch.
Expand Down

0 comments on commit c645d68

Please sign in to comment.