Merge pull request #670 from lsst/tickets/DM-34247
DM-34247: simplify dataset subquery logic and fix edge-case bugs
TallJimbo committed Apr 11, 2022
2 parents 6ac0126 + d168fa6 commit cbf5f61
Showing 8 changed files with 390 additions and 187 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-34247.bugfix.md
@@ -0,0 +1 @@
Fix Registry.queryDataIds bug involving dataset constraints with no dimensions.
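As a reading aid (not part of the diff): a minimal sketch of the kind of call this entry refers to, assuming a dataset type whose dimension set is empty; the repository path, dataset type, and collection names are made up.

```python
# Illustrative only: a dataset type with no dimensions of its own, used purely
# as an existence constraint on a dimension query (the bug's trigger case).
from lsst.daf.butler import Butler

butler = Butler("some/repo")  # hypothetical repository
data_ids = butler.registry.queryDataIds(
    ["instrument", "detector"],
    datasets="my_dimensionless_dataset",  # hypothetical dataset type with no dimensions
    collections="some_run",               # hypothetical collection
)
print(set(data_ids))
```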
1 change: 1 addition & 0 deletions doc/changes/DM-34328.bugfix.md
@@ -0,0 +1 @@
Fix Registry.queryCollections bug in which children of chained collections were being sorted alphabetically instead of being returned in the order in which they would be searched.
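Again as a reading aid: a sketch of the behavior this entry describes, with hypothetical repository and collection names.

```python
# Illustrative only: children of a CHAINED collection define its search order,
# and queryCollections should report them in that order, not sorted.
from lsst.daf.butler import Butler, CollectionType

butler = Butler("some/repo", writeable=True)  # hypothetical repository
registry = butler.registry
registry.registerRun("run_b")
registry.registerRun("run_a")
registry.registerCollection("chain", CollectionType.CHAINED)
registry.setCollectionChain("chain", ["run_b", "run_a"])  # search run_b first
# With this fix the flattened result preserves search order:
# ['run_b', 'run_a'] rather than the alphabetical ['run_a', 'run_b'].
print(list(registry.queryCollections("chain", flattenChains=True)))
```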
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registries/sql.py
@@ -1178,7 +1178,7 @@ def queryDatasetAssociations(
flattenChains=flattenChains,
):
query = storage.select(collectionRecord)
for row in self._db.query(query.combine()).mappings():
for row in self._db.query(query).mappings():
dataId = DataCoordinate.fromRequiredValues(
storage.datasetType.dimensions,
tuple(row[name] for name in storage.datasetType.dimensions.required.names),
88 changes: 60 additions & 28 deletions python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
@@ -3,7 +3,7 @@
__all__ = ("ByDimensionsDatasetRecordStorage",)

import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple

import sqlalchemy
from lsst.daf.butler import (
@@ -77,7 +77,6 @@ def find(
sql = self.select(
collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan
)
sql = sql.combine()
results = self._db.query(sql)
row = results.fetchone()
if row is None:
@@ -324,12 +323,12 @@ def select(
run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None,
) -> SimpleQuery:
) -> sqlalchemy.sql.Selectable:
# Docstring inherited from DatasetRecordStorage.
collection_types = {collection.type for collection in collections}
assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
#
# There are two tables in play here:
# There are two kinds of table in play here:
#
# - the static dataset table (with the dataset ID, dataset type ID,
# run ID/name, and ingest date);
@@ -353,12 +352,11 @@
# redundant columns in the JOIN ON expression, however, because the
# FOREIGN KEY (and its index) are defined only on dataset_id.
#
# We'll start with an empty SimpleQuery, and accumulate kwargs to pass
# to its `join` method when we bring in the tags/calibs table.
query = SimpleQuery()
# We get the data ID or constrain it in the tags/calibs table, but
# that's multiple columns, not one, so we need to transform the one
# Select.Or argument into a dictionary of them.
# We'll start by accumulating kwargs to pass to SimpleQuery.join when
# we bring in the tags/calibs table. We get the data ID or constrain
# it in the tags/calibs table(s), but that's multiple columns, not one,
# so we need to transform the one Select.Or argument into a dictionary
# of them.
kwargs: Dict[str, Any]
if dataId is SimpleQuery.Select:
kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
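Editor's sketch (assuming SQLAlchemy 1.4-style `select()`, not the real butler DDL): a mock of the two-table layout the comment above describes, with the static dataset table joined to a tags table on `dataset_id` and the dataset type constrained in the tags table.

```python
# Simplified mock schema: a static "dataset" table plus a per-dimensions
# "tags" table, joined on dataset_id; column names here are illustrative.
import sqlalchemy

metadata = sqlalchemy.MetaData()
dataset = sqlalchemy.Table(
    "dataset",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.BigInteger, primary_key=True),
    sqlalchemy.Column("dataset_type_id", sqlalchemy.BigInteger),
    sqlalchemy.Column("run_id", sqlalchemy.BigInteger),
)
tags = sqlalchemy.Table(
    "dataset_tags",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.BigInteger),
    sqlalchemy.Column("dataset_type_id", sqlalchemy.BigInteger),
    sqlalchemy.Column("collection_id", sqlalchemy.BigInteger),
    sqlalchemy.Column("instrument", sqlalchemy.String(32)),  # example dimension column
    sqlalchemy.Column("detector", sqlalchemy.Integer),       # example dimension column
)
query = (
    sqlalchemy.select(tags.c.instrument, tags.c.detector, dataset.c.run_id)
    .select_from(tags.join(dataset, tags.c.dataset_id == dataset.c.id))
    .where(tags.c.dataset_type_id == 42)  # dataset type is always constrained here
)
print(query)
```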
@@ -367,9 +365,25 @@
# We always constrain (never retrieve) the dataset type in at least the
# tags/calibs table.
kwargs["dataset_type_id"] = self._dataset_type_id
# Join in the tags or calibs table, turning those 'kwargs' entries into
# WHERE constraints or SELECT columns as appropriate.
if collection_types == {CollectionType.CALIBRATION}:
# Join in the tags and/or calibs tables, turning those 'kwargs' entries
# into WHERE constraints or SELECT columns as appropriate.
if collection_types != {CollectionType.CALIBRATION}:
# We'll need a subquery for the tags table if any of the given
# collections are not a CALIBRATION collection. This intentionally
# also fires when the list of collections is empty as a way to
# create a dummy subquery that we know will fail.
tags_query = SimpleQuery()
tags_query.join(self._tags, **kwargs)
self._finish_single_select(
tags_query, self._tags, collections, id=id, run=run, ingestDate=ingestDate
)
else:
tags_query = None
if CollectionType.CALIBRATION in collection_types:
# If at least one collection is a CALIBRATION collection, we'll
# need a subquery for the calibs table, and could include the
# timespan as a result or constraint.
calibs_query = SimpleQuery()
assert (
self._calibs is not None
), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
@@ -379,23 +393,42 @@
if timespan is SimpleQuery.Select:
kwargs.update({k: SimpleQuery.Select for k in TimespanReprClass.getFieldNames()})
elif timespan is not None:
query.where.append(
calibs_query.where.append(
TimespanReprClass.fromSelectable(self._calibs).overlaps(
TimespanReprClass.fromLiteral(timespan)
)
)
query.join(self._calibs, **kwargs)
dataset_id_col = self._calibs.columns.dataset_id
collection_col = self._calibs.columns[self._collections.getCollectionForeignKeyName()]
elif CollectionType.CALIBRATION not in collection_types:
query.join(self._tags, **kwargs)
dataset_id_col = self._tags.columns.dataset_id
collection_col = self._tags.columns[self._collections.getCollectionForeignKeyName()]
else:
raise TypeError(
"Cannot query for CALIBRATION collections in the same "
"subquery as other kinds of collections."
calibs_query.join(self._calibs, **kwargs)
self._finish_single_select(
calibs_query, self._calibs, collections, id=id, run=run, ingestDate=ingestDate
)
else:
calibs_query = None
if calibs_query is not None:
if tags_query is not None:
if timespan is not None:
raise TypeError(
"Cannot query for timespan when the collections include both calibration and "
"non-calibration collections."
)
return tags_query.combine().union(calibs_query.combine())
else:
return calibs_query.combine()
else:
assert tags_query is not None, "Earlier logic should guarantee at least one is not None."
return tags_query.combine()

def _finish_single_select(
self,
query: SimpleQuery,
table: sqlalchemy.schema.Table,
collections: Sequence[CollectionRecord],
id: SimpleQuery.Select.Or[Optional[int]],
run: SimpleQuery.Select.Or[None],
ingestDate: SimpleQuery.Select.Or[Optional[Timespan]],
) -> None:
dataset_id_col = table.columns.dataset_id
collection_col = table.columns[self._collections.getCollectionForeignKeyName()]
# We always constrain (never retrieve) the collection(s) in the
# tags/calibs table.
if len(collections) == 1:
@@ -405,8 +438,8 @@
# generate a valid SQL query that can't yield results. This should
# never get executed, but lots of downstream code will still try
# to access the SQLAlchemy objects representing the columns in the
# subquery. That's not idea, but it'd take a lot of refactoring to
# fix it.
# subquery. That's not ideal, but it'd take a lot of refactoring
# to fix it (DM-31725).
query.where.append(sqlalchemy.sql.literal(False))
else:
query.where.append(collection_col.in_([collection.key for collection in collections]))
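Editor's aside on the comment above: the "valid SQL query that can't yield results" trick is just a literal `False` in the WHERE clause. A standalone illustration, assuming SQLAlchemy 1.4 and an in-memory SQLite database:

```python
# The columns of the query still exist for downstream code to reference,
# but the literal False guarantees zero rows.
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
t = sqlalchemy.Table("t", metadata, sqlalchemy.Column("x", sqlalchemy.Integer))
metadata.create_all(engine)
query = sqlalchemy.select(t.c.x).where(sqlalchemy.sql.literal(False))
with engine.connect() as conn:
    conn.execute(t.insert(), [{"x": 1}, {"x": 2}])
    print(conn.execute(query).fetchall())  # [] -- valid SQL, no results
```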
@@ -464,7 +497,6 @@ def select(
# that that's a good idea IFF it's in the foreign key, and right
# now it isn't.
query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
return query

def getDataId(self, id: DatasetId) -> DataCoordinate:
"""Return DataId for a dataset.
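Stepping back from the diff: the reworked `select` above builds at most two subqueries (tags and calibs) and unions them only when both kinds of collection are present. A minimal standalone sketch of that shape (mock tables, SQLAlchemy 1.4; not the butler implementation):

```python
# One SELECT per table kind; UNION only when both apply.
import sqlalchemy

metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table(
    "dataset_tags",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.BigInteger),
)
calibs = sqlalchemy.Table(
    "dataset_calibs",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.BigInteger),
)


def make_select(need_tags: bool, need_calibs: bool) -> sqlalchemy.sql.Selectable:
    """Return a tags SELECT, a calibs SELECT, or their UNION."""
    assert need_tags or need_calibs, "At least one subquery must be requested."
    tags_query = sqlalchemy.select(tags.c.dataset_id) if need_tags else None
    calibs_query = sqlalchemy.select(calibs.c.dataset_id) if need_calibs else None
    if tags_query is not None and calibs_query is not None:
        return tags_query.union(calibs_query)
    return tags_query if tags_query is not None else calibs_query


print(make_select(True, True))
```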
20 changes: 12 additions & 8 deletions python/lsst/daf/butler/registry/interfaces/_database.py
@@ -119,11 +119,11 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table:
_checkExistingTableDefinition(
name, spec, self._inspector.get_columns(name, schema=self._db.namespace)
)
table = self._db._convertTableSpec(name, spec, self._db._metadata)
metadata = self._db._metadata
assert metadata is not None, "Guaranteed by context manager that returns this object."
table = self._db._convertTableSpec(name, spec, metadata)
for foreignKeySpec in spec.foreignKeys:
self._foreignKeys.append(
(table, self._db._convertForeignKeySpec(name, foreignKeySpec, self._db._metadata))
)
self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata)))
return table

def addTableTuple(self, specs: Tuple[ddl.TableSpec, ...]) -> Tuple[sqlalchemy.schema.Table, ...]:
@@ -228,8 +228,11 @@ def makeTemporaryTable(self, spec: ddl.TableSpec, name: Optional[str] = None) ->
"""
if name is None:
name = f"tmp_{uuid.uuid4().hex}"
metadata = self._db._metadata
if metadata is None:
raise RuntimeError("Cannot create temporary table before static schema is defined.")
table = self._db._convertTableSpec(
name, spec, self._db._metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA
name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA
)
if table.key in self._db._tempTables:
if table.key != name:
Expand All @@ -238,7 +241,7 @@ def makeTemporaryTable(self, spec: ddl.TableSpec, name: Optional[str] = None) ->
f"Database) already exists."
)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._db._convertForeignKeySpec(name, foreignKeySpec, self._db._metadata))
table.append_constraint(self._db._convertForeignKeySpec(name, foreignKeySpec, metadata))
with self._db._connection() as connection:
table.create(connection)
self._db._tempTables.add(table.key)
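Side note on the temporary-table path above: the SQLAlchemy mechanics are a `TEMPORARY` prefix plus a blank schema on the `Table` object. A standalone sketch against in-memory SQLite (assuming SQLAlchemy 1.4; not the butler `Database` API):

```python
# CREATE TEMPORARY TABLE via Table prefixes; the table is session-local.
import uuid

import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
name = f"tmp_{uuid.uuid4().hex}"
tmp = sqlalchemy.Table(
    name,
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.BigInteger),
    prefixes=["TEMPORARY"],
    schema=sqlalchemy.schema.BLANK_SCHEMA,
)
with engine.connect() as connection:
    tmp.create(connection)
    connection.execute(tmp.insert(), [{"dataset_id": 1}])
    print(connection.execute(sqlalchemy.select(tmp.c.dataset_id)).fetchall())
```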
@@ -511,6 +514,7 @@ def transaction(
# `Connection.in_nested_transaction()` method.
savepoint = savepoint or connection.info.get(_IN_SAVEPOINT_TRANSACTION, False)
connection.info[_IN_SAVEPOINT_TRANSACTION] = savepoint
trans: sqlalchemy.engine.Transaction
if connection.in_transaction() and savepoint:
trans = connection.begin_nested()
elif not connection.in_transaction():
@@ -1676,13 +1680,13 @@ def update(self, table: sqlalchemy.schema.Table, where: Dict[str, str], *rows: d
return connection.execute(sql, rows).rowcount

def query(
self, sql: sqlalchemy.sql.FromClause, *args: Any, **kwargs: Any
self, sql: sqlalchemy.sql.Selectable, *args: Any, **kwargs: Any
) -> sqlalchemy.engine.ResultProxy:
"""Run a SELECT query against the database.
Parameters
----------
sql : `sqlalchemy.sql.FromClause`
sql : `sqlalchemy.sql.Selectable`
A SQLAlchemy representation of a ``SELECT`` query.
*args
Additional positional arguments are forwarded to
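Why the wider annotation matters: the UNION produced by the new dataset-select logic is a `CompoundSelect`, which is a `Selectable` but (in SQLAlchemy 1.4) not a `FromClause`, so the old annotation would have rejected exactly the queries this ticket introduces. A quick check, assuming SQLAlchemy 1.4:

```python
# A UNION of two SELECTs is Selectable but not a FromClause in SQLAlchemy 1.4.
import sqlalchemy

metadata = sqlalchemy.MetaData()
t = sqlalchemy.Table("t", metadata, sqlalchemy.Column("x", sqlalchemy.Integer))
u = sqlalchemy.select(t.c.x).union(sqlalchemy.select(t.c.x))
print(isinstance(u, sqlalchemy.sql.Selectable))  # True
print(isinstance(u, sqlalchemy.sql.FromClause))  # False
```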
13 changes: 8 additions & 5 deletions python/lsst/daf/butler/registry/interfaces/_datasets.py
@@ -27,6 +27,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Optional, Tuple

import sqlalchemy.sql

from ...core import DataCoordinate, DatasetId, DatasetRef, DatasetType, SimpleQuery, Timespan, ddl
from ._versioning import VersionedExtension

@@ -318,7 +320,7 @@ def select(
run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None,
) -> SimpleQuery:
) -> sqlalchemy.sql.Selectable:
"""Return a SQLAlchemy object that represents a ``SELECT`` query for
this `DatasetType`.
@@ -351,7 +353,9 @@
If `Select` (default), include the validity range timespan in the
result columns. If a `Timespan` instance, constrain the results to
those whose validity ranges overlap that given timespan. Ignored
unless ``collection.type is CollectionType.CALIBRATION``.
for collection types other than `~CollectionType.CALIBRATION`,
but `None` should be passed explicitly if a mix of
`~CollectionType.CALIBRATION` and other types is passed in.
ingestDate : `None`, `Select`, or `Timespan`
If `Select` include the ingest timestamp in the result columns.
If a `Timespan` instance, constrain the results to those whose
@@ -361,9 +365,8 @@
Returns
-------
query : `SimpleQuery`
A struct containing the SQLAlchemy object that representing a
simple ``SELECT`` query.
query : `sqlalchemy.sql.Selectable`
A SQLAlchemy object representing a simple ``SELECT`` query.
"""
raise NotImplementedError()

