Relax requirement on dimensions in followup dataset queries.
This will allow reference catalog queries in QG generation to be
vectorized as long as they use the common skypix system.
TallJimbo committed Aug 8, 2023
1 parent 2099e08 commit c2ecaa9
Showing 6 changed files with 127 additions and 39 deletions.
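As a rough usage sketch of the kind of follow-up query this change permits (the dataset type name "ref_cat", the collection name "refcats", and the data ID values are illustrative assumptions, not taken from this commit, and a `registry` already connected to a repository is assumed): a data ID query over {skymap, tract, patch} can now be followed up with a search for a dataset type whose dimensions, e.g. {htm7}, are not a subset of the query's, with the needed spatial join added automatically.

    # Hypothetical sketch: "ref_cat" (an htm7-dimensioned dataset type) and the
    # "refcats" collection are made-up names.  Before this commit, a follow-up
    # search whose dimensions were not a subset of the query's raised ValueError.
    data_ids = registry.queryDataIds(["patch"], skymap="SkyMap1", tract=0)
    for data_id, ref in data_ids.findRelatedDatasets("ref_cat", ["refcats"]):
        print(data_id, ref.id)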
102 changes: 73 additions & 29 deletions python/lsst/daf/butler/registry/queries/_query.py
@@ -22,6 +22,7 @@

__all__ = ()

import itertools
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
from contextlib import contextmanager
from typing import Any, cast, final
@@ -648,9 +649,6 @@ def find_datasets(
lsst.daf.relation.ColumnError
Raised if a dataset search is already present in this query and
this is a find-first search.
ValueError
Raised if the given dataset type's dimensions are not a subset of
the current query's dimensions.
"""
if find_first and DatasetColumnTag.filter_from(self._relation.columns):
raise ColumnError(
@@ -680,14 +678,6 @@ def find_datasets(
# where we materialize the initial data ID query into a temp table
# and hence can't go back and "recover" those dataset columns anyway;
#
if not (dataset_type.dimensions <= self._dimensions):
raise ValueError(
"Cannot find datasets from a query unless the dataset types's dimensions "
f"({dataset_type.dimensions}, for {dataset_type.name}) are a subset of the query's "
f"({self._dimensions})."
)
columns = set(columns)
columns.add("dataset_id")
collections = CollectionWildcard.from_expression(collections)
if find_first:
collections.require_ordered()
@@ -699,27 +689,81 @@ def find_datasets(
allow_calibration_collections=True,
rejections=rejections,
)
# If the dataset type has dimensions not in the current query, or we
# need a temporal join for a calibration collection, either restore
# those columns or join them in.
full_dimensions = dataset_type.dimensions.union(self._dimensions)
relation = self._relation
record_caches = self._record_caches
base_columns_required: set[ColumnTag] = {
DimensionKeyColumnTag(name) for name in full_dimensions.names
}
spatial_joins: list[tuple[str, str]] = []
if not (dataset_type.dimensions <= self._dimensions):
if self._has_record_columns is True:
# This query is for expanded data IDs, so if we add new
# dimensions to the query we need to be able to get records for
# the new dimensions.
record_caches = dict(self._record_caches)
for element in self.dimensions.elements:
if element in record_caches:
continue
if (
cache := self._backend.get_dimension_record_cache(element.name, self._context)
) is not None:
record_caches[element] = cache
else:
base_columns_required.update(element.RecordClass.fields.columns.keys())
# See if we need spatial joins between the current query and the
# dataset type's dimensions. The logic here is for multiple
# spatial joins in general, but in practice it'll be exceedingly
# rare for there to be more than one. We start by figuring out
# which spatial "families" (observations vs. skymaps, skypix
# systems) are present on only one side and not the other.
lhs_spatial_families = self._dimensions.spatial - dataset_type.dimensions.spatial
rhs_spatial_families = dataset_type.dimensions.spatial - self._dimensions.spatial
# Now we iterate over the Cartesian product of those, so e.g.
# if the query has {tract, patch, visit} and the dataset type
# has {htm7} dimensions, the iterations of this loop
# correspond to: (skymap, htm), (observations, htm).
for lhs_spatial_family, rhs_spatial_family in itertools.product(
lhs_spatial_families, rhs_spatial_families
):
# For each pair we add a join between the most-precise element
# present in each family (e.g. patch beats tract).
spatial_joins.append(
(
lhs_spatial_family.choose(full_dimensions.elements).name,
rhs_spatial_family.choose(full_dimensions.elements).name,
)
)
# Set up any temporal join between the query dimensions and CALIBRATION
# collection's validity ranges.
temporal_join_on: set[ColumnTag] = set()
if any(r.type is CollectionType.CALIBRATION for r in collection_records):
for family in self._dimensions.temporal:
element = family.choose(self._dimensions.elements)
temporal_join_on.add(DimensionRecordColumnTag(element.name, "timespan"))
timespan_columns_required = set(temporal_join_on)
relation, columns_found = self._context.restore_columns(self._relation, timespan_columns_required)
timespan_columns_required.difference_update(columns_found)
if timespan_columns_required:
relation = self._backend.make_dimension_relation(
self._dimensions,
timespan_columns_required,
self._context,
initial_relation=relation,
# Don't permit joins to use any columns beyond those in the
# original relation, as that would change what this
# operation does.
initial_join_max_columns=frozenset(self._relation.columns),
governor_constraints=self._governor_constraints,
)
endpoint = family.choose(self._dimensions.elements)
temporal_join_on.add(DimensionRecordColumnTag(endpoint.name, "timespan"))
base_columns_required.update(temporal_join_on)
# Note which of the many kinds of potentially-missing columns we have
# and add the rest.
base_columns_required.difference_update(relation.columns)
if base_columns_required:
relation = self._backend.make_dimension_relation(
full_dimensions,
base_columns_required,
self._context,
initial_relation=relation,
# Don't permit joins to use any columns beyond those in the
# original relation, as that would change what this
# operation does.
initial_join_max_columns=frozenset(self._relation.columns),
governor_constraints=self._governor_constraints,
spatial_joins=spatial_joins,
)
# Finally we can join in the search for the dataset query.
columns = set(columns)
columns.add("dataset_id")
if not collection_records:
relation = relation.join(
self._backend.make_doomed_dataset_relation(dataset_type, columns, rejections, self._context)
@@ -742,7 +786,7 @@ def find_datasets(
join_to=relation,
temporal_join_on=temporal_join_on,
)
return self._chain(relation, defer=defer)
return self._chain(relation, dimensions=full_dimensions, record_caches=record_caches, defer=defer)

def sliced(
self,
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registry/queries/_query_backend.py
@@ -28,8 +28,8 @@

from lsst.daf.relation import (
BinaryOperationRelation,
ColumnTag,
ColumnExpression,
ColumnTag,
LeafRelation,
MarkerRelation,
Predicate,
7 changes: 2 additions & 5 deletions python/lsst/daf/butler/registry/queries/_results.py
@@ -254,8 +254,6 @@ def findDatasets(
Raises
------
ValueError
Raised if ``datasetType.dimensions.issubset(self.graph) is False``.
MissingDatasetTypeError
Raised if the given dataset type is not registered.
"""
@@ -314,12 +312,11 @@ def findRelatedDatasets(
Raises
------
ValueError
Raised if ``datasetType.dimensions.issubset(self.graph) is False``
or ``dimensions.issubset(self.graph) is False``.
MissingDatasetTypeError
Raised if the given dataset type is not registered.
"""
if dimensions is None:
dimensions = self.graph
parent_dataset_type, _ = self._query.backend.resolve_single_dataset_type_wildcard(
datasetType, components=False, explicit_only=True
)
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/registry/queries/_sql_query_backend.py
@@ -245,6 +245,10 @@ def make_dimension_relation(
"it is part of a dataset subquery, spatial join, or other initial relation."
)

# Before joining in new tables to provide columns, attempt to restore
# them from the given relation by weakening projections applied to it.
relation, _ = context.restore_columns(relation, columns_required)

# Categorize columns not yet included in the relation to associate them
# with dimension elements and detect bad inputs.
missing_columns = ColumnCategorization.from_iterable(columns_required - relation.columns)
49 changes: 46 additions & 3 deletions python/lsst/daf/butler/registry/tests/_registry.py
@@ -1489,9 +1489,12 @@ def testQueryResults(self):
expectedDeduplicatedBiases,
)

# Check dimensions match.
with self.assertRaises(ValueError):
subsetDataIds.findDatasets("flat", collections=["imported_r", "imported_g"], findFirst=True)
# Searching for a dataset with dimensions we had projected away
# restores those dimensions.
self.assertCountEqual(
list(subsetDataIds.findDatasets("flat", collections=["imported_r"], findFirst=True)),
expectedFlats,
)

# Use a component dataset type.
self.assertCountEqual(
@@ -3630,3 +3633,43 @@ def test_query_empty_collections(self) -> None:
messages = list(result.explain_no_results())
self.assertTrue(messages)
self.assertTrue(any("because collection list is empty" in message for message in messages))

def test_dataset_followup_spatial_joins(self) -> None:
"""Test queryDataIds(...).findRelatedDatasets(...) where a spatial join
is involved.
"""
registry = self.makeRegistry()
self.loadData(registry, "base.yaml")
self.loadData(registry, "spatial.yaml")
pvi_dataset_type = DatasetType(
"pvi", {"visit", "detector"}, storageClass="StructuredDataDict", universe=registry.dimensions
)
registry.registerDatasetType(pvi_dataset_type)
collection = "datasets"
registry.registerRun(collection)
(pvi1,) = registry.insertDatasets(
pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 1}], run=collection
)
(pvi2,) = registry.insertDatasets(
pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 2}], run=collection
)
(pvi3,) = registry.insertDatasets(
pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 3}], run=collection
)
self.assertEqual(
set(
registry.queryDataIds(["patch"], skymap="SkyMap1", tract=0).findRelatedDatasets(
"pvi", [collection]
)
),
{
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=0), pvi1),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=0), pvi2),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=1), pvi2),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi1),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi2),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi3),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=3), pvi2),
(registry.expandDataId(skymap="SkyMap1", tract=0, patch=4), pvi3),
},
)
2 changes: 1 addition & 1 deletion tests/data/registry/spatial.py
@@ -252,7 +252,7 @@ def make_plots(detector_grid: bool, patch_grid: bool):
index_labels(color="black", alpha=0.5),
)
colors = iter(["red", "blue", "cyan", "green"])
for (visit_id, visit_data), color in zip(VISIT_DATA.items(), colors, strict=True):
for (visit_id, visit_data), color in zip(VISIT_DATA.items(), colors, strict=False):
for detector_id, pixel_indices in visit_data["detector_regions"].items():
label = f"visit={visit_id}"
if label in labels_used:
