Skip to content

Commit

Permalink
Remove dataset type caching from datasets manager.
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-slac committed Nov 9, 2023
1 parent b1d4cb8 commit f2e1c09
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 129 deletions.
229 changes: 118 additions & 111 deletions python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py
Expand Up @@ -4,6 +4,7 @@

__all__ = ("ByDimensionsDatasetRecordStorageManagerUUID",)

import dataclasses
import logging
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -55,6 +56,16 @@ class MissingDatabaseTableError(RuntimeError):
"""Exception raised when a table is not found in a database."""


@dataclasses.dataclass
class _DatasetTypeRecord:
    """Contents of a single dataset type record."""

    # Dataset type definition reconstructed from the database row.
    dataset_type: DatasetType
    # Primary key of the dataset type in the static ``dataset_type`` table.
    dataset_type_id: int
    # Name of the table holding tag associations for this dataset type.
    tag_table_name: str
    # Name of the calibration-association table, or `None` when the dataset
    # type is not a calibration (see ``_record_from_row``, which sets
    # ``isCalibration`` from this field being non-None).
    calib_table_name: str | None


class _ExistingTableFactory:
"""Factory for `sqlalchemy.schema.Table` instances that returns already
existing table instance.
Expand Down Expand Up @@ -139,8 +150,6 @@ def __init__(
self._dimensions = dimensions
self._static = static
self._summaries = summaries
self._byName: dict[str, ByDimensionsDatasetRecordStorage] = {}
self._byId: dict[int, ByDimensionsDatasetRecordStorage] = {}

@classmethod
def initialize(
Expand All @@ -162,6 +171,7 @@ def initialize(
context,
collections=collections,
dimensions=dimensions,
dataset_type_table=static.dataset_type,
)
return cls(
db=db,
Expand Down Expand Up @@ -236,44 +246,33 @@ def addDatasetForeignKey(

def refresh(self) -> None:
    # Docstring inherited from DatasetRecordStorageManager.
    # The in-memory dataset type cache (_byName/_byId) has been removed;
    # records are now read from the database on every lookup, so there is
    # nothing to refresh.
    pass

def _make_storage(self, record: _DatasetTypeRecord) -> ByDimensionsDatasetRecordStorage:
    """Create storage instance for a dataset type record.

    Parameters
    ----------
    record : `_DatasetTypeRecord`
        Record read from the ``dataset_type`` table.

    Returns
    -------
    storage : `ByDimensionsDatasetRecordStorage`
        Newly constructed storage instance for this dataset type.
    """
    # Tag table always exists for a registered dataset type.
    tags_spec = makeTagTableSpec(record.dataset_type, type(self._collections), self.getIdColumnType())
    tags_table_factory = _SpecTableFactory(self._db, record.tag_table_name, tags_spec)
    # Calibration-association table exists only for calibration dataset
    # types (calib_table_name is None otherwise).
    calibs_table_factory = None
    if record.calib_table_name is not None:
        calibs_spec = makeCalibTableSpec(
            record.dataset_type,
            type(self._collections),
            self._db.getTimespanRepresentation(),
            self.getIdColumnType(),
        )
        calibs_table_factory = _SpecTableFactory(self._db, record.calib_table_name, calibs_spec)
    storage = self._recordStorageType(
        db=self._db,
        datasetType=record.dataset_type,
        static=self._static,
        summaries=self._summaries,
        tags_table_factory=tags_table_factory,
        calibs_table_factory=calibs_table_factory,
        dataset_type_id=record.dataset_type_id,
        collections=self._collections,
        use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai,
    )
    return storage

def remove(self, name: str) -> None:
# Docstring inherited from DatasetRecordStorageManager.
Expand All @@ -296,33 +295,28 @@ def remove(self, name: str) -> None:

def find(self, name: str) -> DatasetRecordStorage | None:
    # Docstring inherited from DatasetRecordStorageManager.
    # With the dataset type cache removed, every lookup goes to the
    # database; a fresh storage instance is built per call.
    record = self._fetch_dataset_type_record(name)
    return self._make_storage(record) if record is not None else None

def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]:
def register(self, datasetType: DatasetType) -> bool:
# Docstring inherited from DatasetRecordStorageManager.
if datasetType.isComponent():
raise ValueError(
f"Component dataset types can not be stored in registry. Rejecting {datasetType.name}"
)
storage = self._byName.get(datasetType.name)
if storage is None:
record = self._fetch_dataset_type_record(datasetType.name)
if record is None:
dimensionsKey = self._dimensions.saveDimensionGraph(datasetType.dimensions)
tagTableName = makeTagTableName(datasetType, dimensionsKey)
calibTableName = (
makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None
)
# The order is important here, we want to create tables first and
# only register them if this operation is successful. We cannot
# wrap it into a transaction because database class assumes that
# DDL is not transaction safe in general.
tags = self._db.ensureTableExists(
self._db.ensureTableExists(
tagTableName,
makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()),
)
tags_table_factory = _ExistingTableFactory(tags)
calibs_table_factory = None
calibTableName = (
makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None
)
if calibTableName is not None:
calibs = self._db.ensureTableExists(
self._db.ensureTableExists(
calibTableName,
makeCalibTableSpec(
datasetType,
Expand All @@ -331,8 +325,7 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool
self.getIdColumnType(),
),
)
calibs_table_factory = _ExistingTableFactory(calibs)
row, inserted = self._db.sync(
_, inserted = self._db.sync(
self._static.dataset_type,
keys={"name": datasetType.name},
compared={
Expand All @@ -347,28 +340,17 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool
},
returning=["id", "tag_association_table"],
)
assert row is not None
storage = self._recordStorageType(
db=self._db,
datasetType=datasetType,
static=self._static,
summaries=self._summaries,
tags_table_factory=tags_table_factory,
calibs_table_factory=calibs_table_factory,
dataset_type_id=row["id"],
collections=self._collections,
use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai,
)
self._byName[datasetType.name] = storage
self._byId[storage._dataset_type_id] = storage
else:
if datasetType != storage.datasetType:
if datasetType != record.dataset_type:
raise ConflictingDefinitionError(
f"Given dataset type {datasetType} is inconsistent "
f"with database definition {storage.datasetType}."
f"with database definition {record.dataset_type}."
)
inserted = False
return storage, bool(inserted)
# TODO: We return storage instance from this method, but the only
# client that uses this method ignores it. Maybe we should drop it
# and avoid making storage instance above.
return bool(inserted)

def resolve_wildcard(
self,
Expand Down Expand Up @@ -422,15 +404,13 @@ def resolve_wildcard(
raise TypeError(
"Universal wildcard '...' is not permitted for dataset types in this context."
)
for storage in self._byName.values():
result[storage.datasetType].add(None)
for datasetType in self._fetch_dataset_types():
result[datasetType].add(None)
if components:
try:
result[storage.datasetType].update(
storage.datasetType.storageClass.allComponents().keys()
)
result[datasetType].update(datasetType.storageClass.allComponents().keys())
if (
storage.datasetType.storageClass.allComponents()
datasetType.storageClass.allComponents()
and not already_warned
and components_deprecated
):
Expand All @@ -442,7 +422,7 @@ def resolve_wildcard(
already_warned = True
except KeyError as err:
_LOG.warning(
f"Could not load storage class {err} for {storage.datasetType.name}; "
f"Could not load storage class {err} for {datasetType.name}; "
"if it has components they will not be included in query results.",
)
elif wildcard.patterns:
Expand All @@ -454,29 +434,28 @@ def resolve_wildcard(
FutureWarning,
stacklevel=find_outside_stacklevel("lsst.daf.butler"),
)
for storage in self._byName.values():
if any(p.fullmatch(storage.datasetType.name) for p in wildcard.patterns):
result[storage.datasetType].add(None)
dataset_types = self._fetch_dataset_types()
for datasetType in dataset_types:
if any(p.fullmatch(datasetType.name) for p in wildcard.patterns):
result[datasetType].add(None)
if components is not False:
for storage in self._byName.values():
if components is None and storage.datasetType in result:
for datasetType in dataset_types:
if components is None and datasetType in result:
continue
try:
components_for_parent = storage.datasetType.storageClass.allComponents().keys()
components_for_parent = datasetType.storageClass.allComponents().keys()
except KeyError as err:
_LOG.warning(
f"Could not load storage class {err} for {storage.datasetType.name}; "
f"Could not load storage class {err} for {datasetType.name}; "
"if it has components they will not be included in query results."
)
continue
for component_name in components_for_parent:
if any(
p.fullmatch(
DatasetType.nameWithComponent(storage.datasetType.name, component_name)
)
p.fullmatch(DatasetType.nameWithComponent(datasetType.name, component_name))
for p in wildcard.patterns
):
result[storage.datasetType].add(component_name)
result[datasetType].add(component_name)
if not already_warned and components_deprecated:
warnings.warn(
deprecation_message,
Expand All @@ -492,49 +471,77 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None:
sqlalchemy.sql.select(
self._static.dataset.columns.dataset_type_id,
self._static.dataset.columns[self._collections.getRunForeignKeyName()],
*self._static.dataset_type.columns,
)
.select_from(self._static.dataset)
.join(self._static.dataset_type)
.where(self._static.dataset.columns.id == id)
)
with self._db.query(sql) as sql_result:
row = sql_result.mappings().fetchone()
if row is None:
return None
recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id])
if recordsForType is None:
self.refresh()
recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id])
assert recordsForType is not None, "Should be guaranteed by foreign key constraints."
storage = self._make_storage(self._record_from_row(row))
return DatasetRef(
recordsForType.datasetType,
dataId=recordsForType.getDataId(id=id),
storage.datasetType,
dataId=storage.getDataId(id=id),
id=id,
run=self._collections[row[self._collections.getRunForeignKeyName()]].name,
)

def _dataset_type_factory(self, dataset_type_id: int) -> DatasetType:
"""Return dataset type given its ID."""
return self._byId[dataset_type_id].datasetType
def _fetch_dataset_type_record(self, name: str) -> _DatasetTypeRecord | None:
    """Retrieve the record for a single named dataset type.

    Parameters
    ----------
    name : `str`
        Name of the dataset type to look up.

    Returns
    -------
    record : `_DatasetTypeRecord` or `None`
        Information from the matching database record, or `None` if no
        dataset type with this name is defined.
    """
    c = self._static.dataset_type.columns
    stmt = self._static.dataset_type.select().where(c.name == name)
    with self._db.query(stmt) as sql_result:
        # one_or_none raises if more than one row matches; name is expected
        # to be unique in the dataset_type table.
        row = sql_result.mappings().one_or_none()
    if row is None:
        return None
    else:
        return self._record_from_row(row)

def _record_from_row(self, row: Mapping) -> _DatasetTypeRecord:
    """Convert a ``dataset_type`` table row into a `_DatasetTypeRecord`."""
    calib_table = row["calibration_association_table"]
    graph = self._dimensions.loadDimensionGraph(row["dimensions_key"])
    # A dataset type is a calibration exactly when a calibration table
    # name was stored for it.
    dataset_type = DatasetType(
        row["name"],
        graph,
        row["storage_class"],
        isCalibration=calib_table is not None,
    )
    return _DatasetTypeRecord(
        dataset_type=dataset_type,
        dataset_type_id=row["id"],
        tag_table_name=row["tag_association_table"],
        calib_table_name=calib_table,
    )

def _dataset_type_from_row(self, row: Mapping) -> DatasetType:
return self._record_from_row(row).dataset_type

def _fetch_dataset_types(self) -> list[DatasetType]:
"""Fetch list of all defined dataset types."""
with self._db.query(self._static.dataset_type.select()) as sql_result:
sql_rows = sql_result.mappings().fetchall()
return [self._record_from_row(row).dataset_type for row in sql_rows]

def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary:
    # Docstring inherited from DatasetRecordStorageManager.
    # Delegate to the summaries manager for this single collection, then
    # pull out the entry keyed by the collection's key.
    per_collection = self._summaries.fetch_summaries([collection], None, self._dataset_type_from_row)
    return per_collection[collection.key]

def fetch_summaries(
    self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None
) -> Mapping[Any, CollectionSummary]:
    # Docstring inherited from DatasetRecordStorageManager.
    # With the dataset type cache removed, summaries are filtered by
    # dataset type *name* instead of by cached numeric ID.
    dataset_type_names: Iterable[str] | None = None
    if dataset_types is not None:
        # NOTE(review): the previous implementation mapped component
        # dataset types to their composite before the lookup; confirm that
        # `_summaries.fetch_summaries` handles component names (or that
        # callers no longer pass components).
        dataset_type_names = set(dataset_type.name for dataset_type in dataset_types)
    return self._summaries.fetch_summaries(collections, dataset_type_names, self._dataset_type_from_row)

_versions: list[VersionTuple]
"""Schema version for this class."""
Expand Down

0 comments on commit f2e1c09

Please sign in to comment.