Add protection to datasets manager import method (DM-31287)
The dataset storage manager now has additional protection against
inconsistent dataset definitions when importing datasets that use UUIDs
for dataset IDs.
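
As a rough illustration (not part of the diff), the guarded failure mode looks like this. It assumes `storage` is an instance of the UUID-based record storage class changed below, `run` is the target `RunRecord`, `data_id` is a valid data coordinate for the dataset type, and a dataset with the same UUID but a different run or data ID is already registered; all of those names are placeholders.

    import uuid

    from lsst.daf.butler import DatasetRef
    from lsst.daf.butler.registry import ConflictingDefinitionError

    # Stand-in for a UUID that already identifies a stored dataset whose
    # run or data ID differs from what we are about to import.
    existing_id = uuid.uuid4()
    clashing_ref = DatasetRef(storage.datasetType, data_id,
                              id=existing_id, run="some_other_run")

    try:
        # import_() now stages the incoming rows, validates them against the
        # existing dataset/tags rows, and raises instead of silently writing
        # an inconsistent definition.
        refs = list(storage.import_(run, [clashing_ref]))
    except ConflictingDefinitionError as err:
        print(f"import rejected: {err}")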
andy-slac committed Oct 8, 2021
1 parent 27e3c37 commit d27114f
Showing 5 changed files with 369 additions and 53 deletions.
221 changes: 183 additions & 38 deletions python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
@@ -4,7 +4,6 @@

from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
@@ -27,11 +26,13 @@
DatasetType,
SimpleQuery,
Timespan,
ddl
)
from lsst.daf.butler.registry import ConflictingDefinitionError, UnsupportedIdGeneratorError
from lsst.daf.butler.registry.interfaces import DatasetRecordStorage, DatasetIdGenEnum

from ...summaries import GovernorDimensionRestriction
from .tables import makeTagTableSpec

if TYPE_CHECKING:
from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
@@ -587,6 +588,9 @@ def insert(self, run: RunRecord, dataIds: Iterable[DataCoordinate],
idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE) -> Iterator[DatasetRef]:
# Docstring inherited from DatasetRecordStorage.

# Remember any governor dimension values we see.
governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)

# Iterate over data IDs, transforming a possibly-single-pass iterable
# into a list.
dataIdList = []
@@ -598,46 +602,11 @@ def insert(self, run: RunRecord, dataIds: Iterable[DataCoordinate],
"dataset_type_id": self._dataset_type_id,
self._runKeyColumn: run.key,
})

yield from self._insert(run, dataIdList, rows, self._db.insert)

def import_(self, run: RunRecord, datasets: Iterable[DatasetRef],
idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
reuseIds: bool = False) -> Iterator[DatasetRef]:
# Docstring inherited from DatasetRecordStorage.

# Iterate over data IDs, transforming a possibly-single-pass iterable
# into a list.
dataIdList = []
rows = []
for dataset in datasets:
dataIdList.append(dataset.dataId)
# Ignore unknown ID types, normally all IDs have the same type but
# this code supports mixed types or missing IDs.
datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
if datasetId is None:
datasetId = self._makeDatasetId(run, dataset.dataId, idGenerationMode)
rows.append({
"id": datasetId,
"dataset_type_id": self._dataset_type_id,
self._runKeyColumn: run.key,
})

yield from self._insert(run, dataIdList, rows, self._db.ensure)

def _insert(self, run: RunRecord, dataIdList: List[DataCoordinate],
rows: List[Dict], insertMethod: Callable) -> Iterator[DatasetRef]:
"""Common part of implementation of `insert` and `import_` methods.
"""

# Remember any governor dimension values we see.
governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
for dataId in dataIdList:
governorValues.update_extract(dataId)

with self._db.transaction():
# Insert into the static dataset table.
insertMethod(self._static.dataset, *rows)
self._db.insert(self._static.dataset, *rows)
# Update the summary tables for this collection in case this is the
# first time this dataset type or these governor values will be
# inserted there.
@@ -653,7 +622,8 @@ def _insert(self, run: RunRecord, dataIdList: List[DataCoordinate],
for dataId, row in zip(dataIdList, rows)
]
# Insert those rows into the tags table.
insertMethod(self._tags, *tagsRows)
self._db.insert(self._tags, *tagsRows)

for dataId, row in zip(dataIdList, rows):
yield DatasetRef(
datasetType=self.datasetType,
@@ -662,6 +632,181 @@ def _insert(self, run: RunRecord, dataIdList: List[DataCoordinate],
run=run.name,
)

def import_(self, run: RunRecord, datasets: Iterable[DatasetRef],
idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
reuseIds: bool = False) -> Iterator[DatasetRef]:
# Docstring inherited from DatasetRecordStorage.

# Remember any governor dimension values we see.
governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)

# Iterate over data IDs, transforming a possibly-single-pass iterable
# into a list.
dataIds = {}
for dataset in datasets:
# Ignore unknown ID types, normally all IDs have the same type but
# this code supports mixed types or missing IDs.
datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
if datasetId is None:
datasetId = self._makeDatasetId(run, dataset.dataId, idGenerationMode)
dataIds[datasetId] = dataset.dataId
governorValues.update_extract(dataset.dataId)

with self._db.session() as session:

# insert all new rows into a temporary table
tableSpec = makeTagTableSpec(self.datasetType, type(self._collections),
ddl.GUID, constraints=False)
tmp_tags = session.makeTemporaryTable(tableSpec)

collFkName = self._collections.getCollectionForeignKeyName()
protoTagsRow = {
"dataset_type_id": self._dataset_type_id,
collFkName: run.key,
}
tmpRows = [dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
for dataset_id, dataId in dataIds.items()]

with self._db.transaction():

# store all incoming data in a temporary table
self._db.insert(tmp_tags, *tmpRows)

# There are some checks that we want to make for consistency
# of the new datasets with existing ones.
self._validateImport(tmp_tags, run)

# Before we merge temporary table into dataset/tags we need to
# drop datasets which are already there (and do not conflict).
self._db.deleteWhere(tmp_tags, tmp_tags.columns.dataset_id.in_(
sqlalchemy.sql.select(self._static.dataset.columns.id)
))

# Copy it into dataset table, need to re-label some columns.
self._db.insert(self._static.dataset, select=sqlalchemy.sql.select(
tmp_tags.columns.dataset_id.label("id"),
tmp_tags.columns.dataset_type_id,
tmp_tags.columns[collFkName].label(self._runKeyColumn)
))

# Update the summary tables for this collection in case this
# is the first time this dataset type or these governor values
# will be inserted there.
self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues)

# Copy it into tags table.
self._db.insert(self._tags, select=tmp_tags.select())

# Return refs in the same order as in the input list.
for dataset_id, dataId in dataIds.items():
yield DatasetRef(
datasetType=self.datasetType,
id=dataset_id,
dataId=dataId,
run=run.name,
)

def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
"""Validate imported refs against existing datasets.
Parameters
----------
tmp_tags : `sqlalchemy.schema.Table`
Temporary table with new datasets and the same schema as tags
table.
run : `RunRecord`
The record object describing the `~CollectionType.RUN` collection.
Raises
------
ConflictingDefinitionError
Raise if new datasets conflict with existing ones.
"""
dataset = self._static.dataset
tags = self._tags
collFkName = self._collections.getCollectionForeignKeyName()

# Check that existing datasets have the same dataset type and
# run.
query = sqlalchemy.sql.select(
dataset.columns.id.label("dataset_id"),
dataset.columns.dataset_type_id.label("dataset_type_id"),
tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
dataset.columns[self._runKeyColumn].label("run"),
tmp_tags.columns[collFkName].label("new run")
).select_from(
dataset.join(
tmp_tags,
dataset.columns.id == tmp_tags.columns.dataset_id
)
).where(
sqlalchemy.sql.or_(
dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName]
)
)
result = self._db.query(query)
if (row := result.first()) is not None:
# Only include the first one in the exception message
raise ConflictingDefinitionError(
f"Existing dataset type or run do not match new dataset: {row._asdict()}"
)

# Check that matching dataset in tags table has the same DataId.
query = sqlalchemy.sql.select(
tags.columns.dataset_id,
tags.columns.dataset_type_id.label("type_id"),
tmp_tags.columns.dataset_type_id.label("new type_id"),
*[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
*[tmp_tags.columns[dim].label(f"new {dim}")
for dim in self.datasetType.dimensions.required.names],
).select_from(
tags.join(
tmp_tags,
tags.columns.dataset_id == tmp_tags.columns.dataset_id
)
).where(
sqlalchemy.sql.or_(
tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
*[tags.columns[dim] != tmp_tags.columns[dim]
for dim in self.datasetType.dimensions.required.names]
)
)
result = self._db.query(query)
if (row := result.first()) is not None:
# Only include the first one in the exception message
raise ConflictingDefinitionError(
f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
)

# Check that matching run+dataId have the same dataset ID.
query = sqlalchemy.sql.select(
tags.columns.dataset_type_id.label("dataset_type_id"),
*[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
tags.columns.dataset_id,
tmp_tags.columns.dataset_id.label("new dataset_id"),
tags.columns[collFkName],
tmp_tags.columns[collFkName].label(f"new {collFkName}")
).select_from(
tags.join(
tmp_tags,
sqlalchemy.sql.and_(
tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
tags.columns[collFkName] == tmp_tags.columns[collFkName],
*[tags.columns[dim] == tmp_tags.columns[dim]
for dim in self.datasetType.dimensions.required.names]
)
)
).where(
tags.columns.dataset_id != tmp_tags.columns.dataset_id
)
result = self._db.query(query)
if (row := result.first()) is not None:
# only include the first one in the exception message
raise ConflictingDefinitionError(
f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
)

def _makeDatasetId(self, run: RunRecord, dataId: DataCoordinate,
idGenerationMode: DatasetIdGenEnum) -> uuid.UUID:
"""Generate dataset ID for a dataset.
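
Stripped of bookkeeping, the new `import_` follows a stage, validate, merge pattern: incoming rows land in a session-scoped temporary table, get checked against what is already stored, and only then are copied into the permanent tables. Below is a condensed sketch of that flow using the same `Database`/session calls as the code above; `db`, `tmp_spec`, `rows`, `static_dataset`, `tags_table`, and `validate` stand in for the attributes and helper of the real class.

    import sqlalchemy

    with db.session() as session:
        # Staging table with the tags-table schema, built with
        # makeTagTableSpec(..., constraints=False).
        tmp_tags = session.makeTemporaryTable(tmp_spec)
        with db.transaction():
            # 1. Stage every incoming row.
            db.insert(tmp_tags, *rows)
            # 2. Reject anything inconsistent with existing datasets.
            validate(tmp_tags)
            # 3. Drop rows whose dataset IDs are already stored (they passed
            #    validation, so they are exact duplicates) ...
            db.deleteWhere(
                tmp_tags,
                tmp_tags.columns.dataset_id.in_(
                    sqlalchemy.sql.select(static_dataset.columns.id)),
            )
            # 4. ... and merge the remainder; the real method also copies IDs
            #    into the static dataset table and updates collection summaries.
            db.insert(tags_table, select=tmp_tags.select())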
36 changes: 21 additions & 15 deletions python/lsst/daf/butler/registry/datasets/byDimensions/tables.py
@@ -285,7 +285,7 @@ def makeCalibTableName(datasetType: DatasetType, dimensionsKey: int) -> str:


def makeTagTableSpec(datasetType: DatasetType, collections: Type[CollectionManager],
dtype: type) -> ddl.TableSpec:
dtype: type, *, constraints: bool = True) -> ddl.TableSpec:
"""Construct the specification for a dynamic (DatasetType-dependent) tag
table used by the classes in this package.
@@ -297,9 +297,11 @@ def makeTagTableSpec(datasetType: DatasetType, collections: Type[CollectionManag
collections : `type` [ `CollectionManager` ]
`CollectionManager` subclass that can be used to construct foreign keys
to the run and/or collection tables.
dtype: `type`
dtype : `type`
Type of the FK column, same as the column type of the PK column of
a referenced table (``dataset.id``).
constraints : `bool`, optional
If `False` (`True` is default), do not define foreign key constraints.

Returns
-------
@@ -314,35 +316,39 @@ def makeTagTableSpec(datasetType: DatasetType, collections: Type[CollectionManag
# in the main monolithic dataset table, but we need it here for an
# important unique constraint.
ddl.FieldSpec("dataset_type_id", dtype=sqlalchemy.BigInteger, nullable=False),
],
foreignKeys=[
ddl.ForeignKeySpec("dataset_type", source=("dataset_type_id",), target=("id",)),
]
)
if constraints:
tableSpec.foreignKeys.append(
ddl.ForeignKeySpec("dataset_type", source=("dataset_type_id",), target=("id",))
)
# We'll also have a unique constraint on dataset type, collection, and data
# ID. We only include the required part of the data ID, as that's
# sufficient and saves us from worrying about nulls in the constraint.
constraint = ["dataset_type_id"]
# Add foreign key fields to dataset table (part of the primary key)
addDatasetForeignKey(tableSpec, dtype, primaryKey=True, onDelete="CASCADE")
addDatasetForeignKey(tableSpec, dtype, primaryKey=True, onDelete="CASCADE", constraint=constraints)
# Add foreign key fields to collection table (part of the primary key and
# the data ID unique constraint).
collectionFieldSpec = collections.addCollectionForeignKey(tableSpec, primaryKey=True, onDelete="CASCADE")
collectionFieldSpec = collections.addCollectionForeignKey(tableSpec, primaryKey=True, onDelete="CASCADE",
constraint=constraints)
constraint.append(collectionFieldSpec.name)
# Add foreign key constraint to the collection_summary_dataset_type table.
tableSpec.foreignKeys.append(
ddl.ForeignKeySpec(
"collection_summary_dataset_type",
source=(collectionFieldSpec.name, "dataset_type_id"),
target=(collectionFieldSpec.name, "dataset_type_id"),
if constraints:
tableSpec.foreignKeys.append(
ddl.ForeignKeySpec(
"collection_summary_dataset_type",
source=(collectionFieldSpec.name, "dataset_type_id"),
target=(collectionFieldSpec.name, "dataset_type_id"),
)
)
)
for dimension in datasetType.dimensions.required:
fieldSpec = addDimensionForeignKey(tableSpec, dimension=dimension, nullable=False, primaryKey=False)
fieldSpec = addDimensionForeignKey(tableSpec, dimension=dimension, nullable=False, primaryKey=False,
constraint=constraints)
constraint.append(fieldSpec.name)
# If this is a governor dimension, add a foreign key constraint to the
# collection_summary_<dimension> table.
if isinstance(dimension, GovernorDimension):
if isinstance(dimension, GovernorDimension) and constraints:
tableSpec.foreignKeys.append(
ddl.ForeignKeySpec(
f"collection_summary_{dimension.name}",
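
The new `constraints` flag exists so the import path can build a staging table that matches the tags-table columns without carrying foreign keys, which a temporary table cannot usefully enforce. A minimal sketch of that call, mirroring its use in `_storage.py` above; `dataset_type` and `collections` are assumed to be the dataset type and collection manager held by the surrounding storage class.

    from lsst.daf.butler import ddl
    from lsst.daf.butler.registry.datasets.byDimensions.tables import makeTagTableSpec

    # Same columns as the real tags table, but without foreign key
    # constraints, so it can back session.makeTemporaryTable().
    tmp_spec = makeTagTableSpec(dataset_type, type(collections), ddl.GUID,
                                constraints=False)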
36 changes: 36 additions & 0 deletions python/lsst/daf/butler/registry/interfaces/_database.py
@@ -1553,6 +1553,42 @@ def delete(self, table: sqlalchemy.schema.Table, columns: Iterable[str], *rows:
rowcount += connection.execute(newsql).rowcount
return rowcount

def deleteWhere(self, table: sqlalchemy.schema.Table, where: sqlalchemy.sql.ClauseElement) -> int:
"""Delete rows from a table with pre-constructed WHERE clause.
Parameters
----------
table : `sqlalchemy.schema.Table`
Table that rows should be deleted from.
where: `sqlalchemy.sql.ClauseElement`
The names of columns that will be used to constrain the rows to
be deleted; these will be combined via ``AND`` to form the
``WHERE`` clause of the delete query.
Returns
-------
count : `int`
Number of rows deleted.
Raises
------
ReadOnlyDatabaseError
Raised if `isWriteable` returns `False` when this method is called.
Notes
-----
May be used inside transaction contexts, so implementations may not
perform operations that interrupt transactions.
The default implementation should be sufficient for most derived
classes.
"""
self.assertTableWriteable(table, f"Cannot delete from read-only table {table}.")

sql = table.delete().where(where)
with self._connection() as connection:
return connection.execute(sql).rowcount

def update(self, table: sqlalchemy.schema.Table, where: Dict[str, str], *rows: dict) -> int:
"""Update one or more rows in a table.
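
A short usage sketch for the new `deleteWhere` (illustrative only; `db` is assumed to be a writeable `Database`, and `tmp_tags` / `static_dataset` are existing tables such as the ones used by `import_` above).

    import sqlalchemy

    # Remove every staged row whose dataset_id is already present in the
    # static dataset table, the same kind of pre-built clause import_()
    # passes in.
    n_deleted = db.deleteWhere(
        tmp_tags,
        tmp_tags.columns.dataset_id.in_(
            sqlalchemy.sql.select(static_dataset.columns.id)),
    )
    print(f"deleted {n_deleted} staged rows")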
