Merge pull request #989 from lsst/tickets/DM-43671
DM-43671: Implement atomic collection prepend

Added an atomic chained collection prepend method to Butler. It works by taking a row lock on the parent collection's row in the collections table, which acts as a mutex for modifications to the collection_chain table. The existing setCollectionChain method now also uses this lock, so that it can interact safely with the new operation.
dhirving committed Apr 4, 2024
2 parents 2059aa2 + ce81f8e commit c0af174
Showing 20 changed files with 541 additions and 91 deletions.
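A minimal sketch of the row-lock pattern described in the commit message, in plain SQLAlchemy (table and column names are simplified illustrations; see `_find_and_lock_collection_chain` in the diff below for the actual implementation):

```python
import sqlalchemy


def lock_collection_chain(
    connection: sqlalchemy.Connection, collection: sqlalchemy.Table, parent_name: str
) -> None:
    # SELECT ... FOR UPDATE on the parent's row in the collection table acts
    # as a mutex: any other transaction requesting the same row lock blocks
    # here until the transaction holding it commits or rolls back.
    connection.execute(
        sqlalchemy.select(collection.c.name)
        .where(collection.c.name == parent_name)
        .with_for_update()
    )
    # While the lock is held (i.e. until the enclosing transaction ends), it
    # is safe to delete and insert collection_chain rows for this parent.
```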
3 changes: 3 additions & 0 deletions doc/changes/DM-43671.bugfix.md
@@ -0,0 +1,3 @@
The `flatten` flag for the `butler collection-chain` CLI command now works as documented: it only flattens the specified children instead of flattening the entire collection chain.

`registry.setCollectionChain` will no longer throw unique constraint violation exceptions when there are concurrent calls to this function. Instead, all calls will succeed and the last write will win. As a side-effect of this change, if calls to `setCollectionChain` occur within an explicit call to `Butler.transaction`, other processes attempting to modify the same chain will block until the transaction completes.
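For example (a sketch; the repository path and collection names are hypothetical), wrapping the call in an explicit transaction extends how long the lock is held:

```python
from lsst.daf.butler import Butler

butler = Butler.from_config("/path/to/repo", writeable=True)  # hypothetical repo
with butler.transaction():
    butler.registry.setCollectionChain("u/someone/chain", ["run_a", "run_b"])
    # Until this block exits, other processes trying to modify
    # "u/someone/chain" will block, waiting for the row lock.
```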
1 change: 1 addition & 0 deletions doc/changes/DM-43671.feature.md
@@ -0,0 +1 @@
Added a new method `Butler.prepend_collection_chain`. This allows you to insert collections at the beginning of a chain. It is an "atomic" operation that can be safely used concurrently from multiple processes.
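Usage sketch (reusing a writeable `butler` instance; collection names are hypothetical):

```python
butler.prepend_collection_chain("u/someone/chain", ["u/someone/new-run"])
# A single collection name is also accepted:
butler.prepend_collection_chain("u/someone/chain", "u/someone/another-run")
```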
35 changes: 35 additions & 0 deletions python/lsst/daf/butler/_butler.py
@@ -1735,3 +1735,38 @@ def _clone(
``inferDefaults``, and default data ID.
"""
raise NotImplementedError()

@abstractmethod
def prepend_collection_chain(
self, parent_collection_name: str, child_collection_names: str | Iterable[str]
) -> None:
"""Add children to the beginning of a CHAINED collection.
If any of the children already existed in the chain, they will be moved
to the new position at the beginning of the chain.
Parameters
----------
parent_collection_name : `str`
The name of a CHAINED collection to which we will add new children.
child_collection_names : `Iterable` [ `str ` ] | `str`
A child collection name or list of child collection names to be
added to the parent.
Raises
------
MissingCollectionError
If any of the specified collections do not exist.
CollectionTypeError
If the parent collection is not a CHAINED collection.
CollectionCycleError
If this operation would create a collection cycle.
Notes
-----
If this function is called within a call to ``Butler.transaction``, it
will hold a lock that prevents other processes from modifying the
parent collection until the end of the transaction. Keep these
transactions short.
"""
raise NotImplementedError()
18 changes: 18 additions & 0 deletions python/lsst/daf/butler/_exceptions.py
@@ -28,6 +28,8 @@
"""Specialized Butler exceptions."""
__all__ = (
"CalibrationLookupError",
"CollectionCycleError",
"CollectionTypeError",
"DatasetNotFoundError",
"DimensionNameError",
"ButlerUserError",
@@ -79,6 +81,20 @@ class CalibrationLookupError(LookupError, ButlerUserError):
error_type = "calibration_lookup"


class CollectionCycleError(ValueError, ButlerUserError):
"""Raised when an operation would cause a chained collection to be a child
of itself.
"""

error_type = "collection_cycle"


class CollectionTypeError(CollectionError, ButlerUserError):
"""Exception raised when type of a collection is incorrect."""

error_type = "collection_type"


class DatasetNotFoundError(LookupError, ButlerUserError):
"""The requested dataset could not be found."""

@@ -158,6 +174,8 @@ class UnknownButlerUserError(ButlerUserError):

_USER_ERROR_TYPES: tuple[type[ButlerUserError], ...] = (
CalibrationLookupError,
CollectionCycleError,
CollectionTypeError,
DimensionNameError,
DimensionValueError,
DatasetNotFoundError,
8 changes: 8 additions & 0 deletions python/lsst/daf/butler/direct_butler.py
@@ -48,6 +48,7 @@

from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_class_of
from lsst.utils.iteration import ensure_iterable
from lsst.utils.logging import VERBOSE, getLogger
from sqlalchemy.exc import IntegrityError

@@ -2141,6 +2142,13 @@ def _preload_cache(self) -> None:
"""Immediately load caches that are used for common operations."""
self._registry.preload_cache()

def prepend_collection_chain(
self, parent_collection_name: str, child_collection_names: str | Iterable[str]
) -> None:
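# ``ensure_iterable`` treats a bare ``str`` as a single collection name
# rather than iterating over its characters, so callers may pass either a
# single name or a list of names.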
return self._registry._managers.collections.prepend_collection_chain(
parent_collection_name, list(ensure_iterable(child_collection_names))
)

_config: ButlerConfig
"""Configuration for this Butler instance."""

7 changes: 6 additions & 1 deletion python/lsst/daf/butler/registry/__init__.py
@@ -27,7 +27,12 @@

# Re-export some top-level exception types for backwards compatibility -- these
# used to be part of registry.
from .._exceptions import DimensionNameError, MissingCollectionError, MissingDatasetTypeError
from .._exceptions import (
CollectionTypeError,
DimensionNameError,
MissingCollectionError,
MissingDatasetTypeError,
)
from .._exceptions_legacy import CollectionError, DataIdError, DatasetTypeError, RegistryError

# Registry imports.
5 changes: 0 additions & 5 deletions python/lsst/daf/butler/registry/_exceptions.py
@@ -29,7 +29,6 @@
__all__ = (
"ArgumentError",
"CollectionExpressionError",
"CollectionTypeError",
"ConflictingDefinitionError",
"DataIdValueError",
"DatasetTypeExpressionError",
@@ -64,10 +63,6 @@ class InconsistentDataIdError(DataIdError):
"""


class CollectionTypeError(CollectionError):
"""Exception raised when type of a collection is incorrect."""


class CollectionExpressionError(CollectionError):
"""Exception raised for an incorrect collection expression."""

196 changes: 167 additions & 29 deletions python/lsst/daf/butler/registry/collections/_base.py
@@ -30,15 +30,13 @@

__all__ = ()

import itertools
from abc import abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator, Set
from typing import TYPE_CHECKING, Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast

import sqlalchemy

from ..._exceptions import MissingCollectionError
from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError
from ...timespan_database_representation import TimespanDatabaseRepresentation
from .._collection_type import CollectionType
from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
@@ -77,7 +75,13 @@ def _makeCollectionForeignKey(
return ddl.ForeignKeySpec("collection", source=(sourceColumnName,), target=(collectionIdName,), **kwargs)


CollectionTablesTuple = namedtuple("CollectionTablesTuple", ["collection", "run", "collection_chain"])
_T = TypeVar("_T")


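# Generic over the table representation: the same tuple shape holds table
# specifications (``CollectionTablesTuple[ddl.TableSpec]``) when defining the
# schema and live tables (``CollectionTablesTuple[sqlalchemy.Table]``) at run
# time.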
class CollectionTablesTuple(NamedTuple, Generic[_T]):
collection: _T
run: _T
collection_chain: _T


def makeRunTableSpec(
Expand Down Expand Up @@ -188,7 +192,7 @@ class DefaultCollectionManager(CollectionManager[K]):
def __init__(
self,
db: Database,
tables: CollectionTablesTuple,
tables: CollectionTablesTuple[sqlalchemy.Table],
collectionIdName: str,
*,
caching_context: CachingContext,
@@ -407,36 +411,170 @@ def update_chain(
self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
) -> ChainedCollectionRecord[K]:
# Docstring inherited from CollectionManager.
children_as_wildcard = CollectionWildcard.from_names(children)
for record in self.resolve_wildcard(
children_as_wildcard,
flatten_chains=True,
include_chains=True,
collection_types={CollectionType.CHAINED},
):
if record == chain:
raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
children = list(children)
self._sanity_check_collection_cycles(chain.name, children)

if flatten:
children = tuple(
record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
record.name
for record in self.resolve_wildcard(
CollectionWildcard.from_names(children), flatten_chains=True
)
)

rows = []
position = itertools.count()
names = []
for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
rows.append(
{
"parent": chain.key,
"child": child.key,
"position": next(position),
}
)
names.append(child.name)
child_records = self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False)
names = [child.name for child in child_records]
with self._db.transaction():
self._find_and_lock_collection_chain(chain.name)
self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
self._db.insert(self._tables.collection_chain, *rows)
self._block_for_concurrency_test()
self._insert_collection_chain_rows(chain.key, 0, [child.key for child in child_records])

record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
self._addCachedRecord(record)
return record

def _sanity_check_collection_cycles(
self, parent_collection_name: str, child_collection_names: list[str]
) -> None:
"""Raise an exception if any of the collections in the ``child_names``
list have ``parent_name`` as a child, creating a collection cycle.
This is only a sanity check, and does not guarantee that no collection
cycles are possible. Concurrent updates might allow collection cycles
to be inserted.
"""
for record in self.resolve_wildcard(
CollectionWildcard.from_names(child_collection_names),
flatten_chains=True,
include_chains=True,
collection_types={CollectionType.CHAINED},
):
if record.name == parent_collection_name:
raise CollectionCycleError(
f"Cycle in collection chaining when defining '{parent_collection_name}'."
)

def _insert_collection_chain_rows(
self,
parent_key: K,
starting_position: int,
child_keys: list[K],
) -> None:
rows = [
{
"parent": parent_key,
"child": child,
"position": position,
}
for position, child in enumerate(child_keys, starting_position)
]
self._db.insert(self._tables.collection_chain, *rows)

def _remove_collection_chain_rows(
self,
parent_key: K,
child_keys: list[K],
) -> None:
table = self._tables.collection_chain
where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys))
self._db.deleteWhere(table, where)

def prepend_collection_chain(
self, parent_collection_name: str, child_collection_names: list[str]
) -> None:
if self._caching_context.is_enabled:
# Avoid having cache-maintenance code around that is unlikely to
# ever be used.
raise RuntimeError("Chained collection modification not permitted with active caching context.")

self._sanity_check_collection_cycles(parent_collection_name, child_collection_names)

child_records = self.resolve_wildcard(
CollectionWildcard.from_names(child_collection_names), flatten_chains=False
)
child_keys = [child.key for child in child_records]

with self._db.transaction():
parent_key = self._find_and_lock_collection_chain(parent_collection_name)
self._remove_collection_chain_rows(parent_key, child_keys)
starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys)
self._block_for_concurrency_test()
self._insert_collection_chain_rows(parent_key, starting_position, child_keys)

def _find_lowest_position_in_collection_chain(self, chain_key: K) -> int:
"""Return the lowest-numbered position in a collection chain, or 0 if
the chain is empty.
"""
table = self._tables.collection_chain
query = sqlalchemy.select(sqlalchemy.func.min(table.c.position)).where(table.c.parent == chain_key)
with self._db.query(query) as cursor:
lowest_existing_position = cursor.scalar()

if lowest_existing_position is None:
return 0

return lowest_existing_position

def _find_and_lock_collection_chain(self, collection_name: str) -> K:
"""
Take a row lock on the specified collection's row in the collections
table, and return the collection's primary key.
This lock is used to synchronize updates to collection chains.
The locking strategy requires cooperation from everything modifying the
collection chain table -- all operations that modify collection chains
must obtain this lock first. The database will NOT automatically
prevent modification of tables based on this lock. The only guarantee
is that only one caller will be allowed to hold this lock for a given
collection at a time. Concurrent calls will block until the caller
holding the lock has completed its transaction.
Parameters
----------
collection_name : `str`
Name of the collection whose chain is being modified.
Returns
-------
id : ``K``
The primary key for the given collection.
Raises
------
MissingCollectionError
If the specified collection is not in the database table.
CollectionTypeError
If the specified collection is not a chained collection.
"""
assert self._db.isInTransaction(), (
"Row locks are only held until the end of the current transaction,"
" so it makes no sense to take a lock outside a transaction."
)
assert self._db.isWriteable(), "Collection row locks are only useful for write operations."

query = self._select_pkey_by_name(collection_name).with_for_update()
with self._db.query(query) as cursor:
rows = cursor.all()

if len(rows) == 0:
raise MissingCollectionError(
f"Parent collection {collection_name} not found when updating collection chain."
)
assert len(rows) == 1, "There should only be one entry for each collection in collection table."
r = rows[0]._mapping
if r["type"] != CollectionType.CHAINED:
raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.")
return r["key"]

@abstractmethod
def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
"""Return a SQLAlchemy select statement that will return columns from
the one row in the ``collection` table matching the given name. The
select statement includes two columns:
- ``key`` : the primary key for the collection
- ``type`` : the collection type
"""
raise NotImplementedError()
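As a worked illustration of the position bookkeeping used by `prepend_collection_chain` above (values are made up): new children are written at negative positions so they sort ahead of the existing rows.

```python
# Existing chain rows for the parent, ordered by position:
#   positions [0, 1, 2] -> children ["a", "b", "c"]
existing_positions = [0, 1, 2]
new_children = ["x", "y"]

lowest = min(existing_positions) if existing_positions else 0  # 0 for an empty chain
starting_position = lowest - len(new_children)                 # 0 - 2 = -2

new_rows = list(enumerate(new_children, starting_position))
print(new_rows)  # [(-2, 'x'), (-1, 'y')] -- "x" and "y" now precede "a", "b", "c"
```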
10 changes: 9 additions & 1 deletion python/lsst/daf/butler/registry/collections/nameKey.py
@@ -63,7 +63,9 @@
_LOG = logging.getLogger(__name__)


def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple:
def _makeTableSpecs(
TimespanReprClass: type[TimespanDatabaseRepresentation],
) -> CollectionTablesTuple[ddl.TableSpec]:
return CollectionTablesTuple(
collection=ddl.TableSpec(
fields=[
@@ -283,6 +285,12 @@ def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> li

return records

def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
table = self._tables.collection
return sqlalchemy.select(table.c.name.label("key"), table.c.type).where(
table.c.name == collection_name
)

@classmethod
def currentVersions(cls) -> list[VersionTuple]:
# Docstring inherited from VersionedExtension.