157 changes: 107 additions & 50 deletions python/activator/middleware_interface.py
@@ -376,16 +376,20 @@ def prep_butler(self) -> None:
self.central_butler.registry.refresh()
self.butler.registry.refresh()

with tempfile.NamedTemporaryFile(mode="w+b", suffix=".yaml") as export_file:
with self.central_butler.export(filename=export_file.name, format="yaml") as export:
self._export_refcats(export, center, radius)
self._export_skymap_and_templates(export, center, detector, wcs, self.visit.filters)
self._export_calibs(export, self.visit.detector, self.visit.filters)
self._export_collections(export, self.instrument.makeUmbrellaCollectionName())

self.butler.import_(filename=export_file.name,
directory=self.central_butler.datastore.root,
transfer="copy")
refcat_datasets = list(self._export_refcats(center, radius))
template_datasets = list(self._export_skymap_and_templates(center, detector, wcs, self.visit.filters))
calib_datasets = list(self._export_calibs(self.visit.detector, self.visit.filters))
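# Copy the selected refcats, templates, and calibs, along with their
# dimension records, from the central repo into the local butler in one
# bulk transfer.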
self.butler.transfer_from(self.central_butler,
refcat_datasets + template_datasets + calib_datasets,
transfer="copy",
skip_missing=True,
register_dataset_types=True,
transfer_dimensions=True,
)

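# Collection structure and calib certifications are registry-only metadata
# that transfer_from does not copy, so they are replicated separately.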
self._export_collections(self._get_template_collection())
self._export_collections(self.instrument.makeUmbrellaCollectionName())
self._export_calib_associations(self.instrument.makeCalibrationCollectionName(), calib_datasets)

# Temporary workarounds until we have a prompt-processing default top-level collection
# in shared repos, and raw collection in dev repo, and then we can organize collections
@@ -401,17 +405,20 @@ def _get_template_collection(self):
"""
return self.instrument.makeCollectionName("templates")

def _export_refcats(self, export, center, radius):
"""Export the refcats for this visit from the central butler.
def _export_refcats(self, center, radius):
"""Identify the refcats to export from the central butler.

Parameters
----------
export : `Iterator[RepoExportContext]`
Export context manager.
center : `lsst.geom.SpherePoint`
Center of the region to find refcat shards in.
radius : `lsst.geom.Angle`
Radius to search for refcat shards in.

Returns
-------
refcats : iterable [`DatasetRef`]
The refcats to be exported, after any filtering.
"""
indexer = HtmIndexer(depth=7)
shard_ids, _ = indexer.getShardIds(center, radius+self.padding)
@@ -430,16 +437,13 @@ def _export_refcats(self, export, center, radius):
where=htm_where,
findFirst=True))
_log.debug("Found %d new refcat datasets.", len(refcats))
export.saveDatasets(refcats)
return refcats

def _export_skymap_and_templates(self, export, center, detector, wcs, filter):
"""Export the skymap and templates for this visit from the central
butler.
def _export_skymap_and_templates(self, center, detector, wcs, filter):
"""Identify the skymap and templates to export from the central butler.

Parameters
----------
export : `Iterator[RepoExportContext]`
Export context manager.
center : `lsst.geom.SpherePoint`
Center of the region to load the skymap tract/patches for.
detector : `lsst.afw.cameraGeom.Detector`
@@ -448,6 +452,11 @@ def _export_skymap_and_templates(self, export, center, detector, wcs, filter):
Rough WCS for the upcoming visit, to help find patches.
filter : `str`
Physical filter for which to export templates.

Returns
-------
skymapTemplates : iterable [`DatasetRef`]
The datasets to be exported, after any filtering.
"""
# TODO: This exports the whole skymap, but we want to only export the
# subset of the skymap that covers this data.
@@ -458,7 +467,6 @@ def _export_skymap_and_templates(self, export, center, detector, wcs, filter):
collections=self._COLLECTION_SKYMAP,
findFirst=True))
_log.debug("Found %d new skymap datasets.", len(skymaps))
export.saveDatasets(skymaps)
# Getting only one tract should be safe: we're getting the
# tract closest to this detector, so we should be well within
# the tract bbox.
@@ -483,22 +491,25 @@ def _export_skymap_and_templates(self, export, center, detector, wcs, filter):
findFirst=True))
except _MissingDatasetError as err:
_log.error(err)
templates = set()
else:
_log.debug("Found %d new template datasets.", len(templates))
export.saveDatasets(templates)
self._export_collections(export, self._get_template_collection())
return skymaps | templates

def _export_calibs(self, export, detector_id, filter):
"""Export the calibs for this visit from the central butler.
def _export_calibs(self, detector_id, filter):
"""Identify the calibs to export from the central butler.

Parameters
----------
export : `Iterator[RepoExportContext]`
Export context manager.
detector_id : `int`
Identifier of the detector to load calibs for.
filter : `str`
Physical filter name of the upcoming visit.

Returns
-------
calibs : iterable [`DatasetRef`]
The calibs to be exported, after any filtering.
"""
# TODO: we can't filter by validity range because it's not
# supported in queryDatasets yet.
@@ -520,27 +531,66 @@ def _export_calibs(self, export, detector_id, filter):
_log.debug("Found %d new calib datasets of type '%s'.", n_datasets, dataset_type)
else:
_log.debug("Found 0 new calib datasets.")
export.saveDatasets(
calibs,
elements=[]) # elements=[] means do not export dimension records
return calibs

def _export_collections(self, export, collection):
def _export_collections(self, collection):
"""Export the collection and all its children.

This preserves the collection structure even if some child collections
do not have data. Exporting a collection does not export its datasets.

Parameters
----------
export : `Iterator[RepoExportContext]`
Export context manager.
collection : `str`
The collection to be exported. It is usually a CHAINED collection
and can have many children.
"""
for child in self.central_butler.registry.queryCollections(
collection, flattenChains=True, includeChains=True):
export.saveCollection(child)
src = self.central_butler.registry
dest = self.butler.registry

# Defer setting collection chains until all children are guaranteed to exist
chains = {}
for child in src.queryCollections(collection, flattenChains=True, includeChains=True):
if src.getCollectionType(child) == CollectionType.CHAINED:
chains[child] = src.getCollectionChain(child)
dest.registerCollection(child,
src.getCollectionType(child),
src.getCollectionDocumentation(child))
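# All child collections now exist locally, so the chains can be recreated.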
for chain, children in chains.items():
dest.setCollectionChain(chain, children)

def _export_calib_associations(self, calib_collection, datasets):
"""Export the associations between a set of datasets and a
calibration collection.

Parameters
----------
calib_collection : `str`
The calibration collection, or a chain thereof, containing the
associations. The collection and any children must exist in both
the central and local repos.
datasets : iterable [`lsst.daf.butler.DatasetRef`]
The calib datasets whose associations must be exported. Must be
certified in ``calib_collection`` in the central repo, and must
exist in the local repo.
"""
dataset_types = {ref.datasetType for ref in datasets}
associations = {}
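# Map each dataset ref to its central-repo association so its
# certification timespan can be reproduced locally.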
for dataset_type in dataset_types:
associations.update(
(a.ref, a) for a in self.central_butler.registry.queryDatasetAssociations(
dataset_type,
calib_collection,
collectionTypes={CollectionType.CALIBRATION},
flattenChains=True
)
)
for dataset in datasets:
association = associations[dataset]
# certify is designed to work on groups of datasets; in practice,
# the total number of calibs (~1 of each type) is small enough that
# grouping by timespan isn't worth it.
self.butler.registry.certify(association.collection, [dataset], association.timespan)

@staticmethod
def _count_by_type(refs):
@@ -857,7 +907,8 @@ def _export_subset(self, exposure_ids: set[int],
# Since AP processing is strictly visit-detector, these three
# dimensions should suffice.
# DO NOT assume that visit == exposure!
where=f"exposure in ({', '.join(str(x) for x in exposure_ids)})",
where="exposure in (exposure_ids)",
bind={"exposure_ids": exposure_ids},
instrument=self.instrument.getName(),
detector=self.visit.detector,
))
@@ -868,19 +919,25 @@ def _export_subset(self, exposure_ids: set[int],
# TODO: get a proper synchronization API for Butler
self.central_butler.registry.refresh()

with tempfile.NamedTemporaryFile(mode="w+b", suffix=".yaml") as export_file:
# MUST NOT export governor dimensions, as this causes deadlocks in
# central registry. Can omit most other dimensions (all dimensions,
# after DM-36051) to avoid locks or redundant work.
# TODO: saveDatasets(elements={"exposure", "visit"}) doesn't work.
# Use import(skip_dimensions) until DM-36062 is fixed.
with self.butler.export(filename=export_file.name) as export:
export.saveDatasets(datasets)
self.central_butler.import_(filename=export_file.name,
directory=self.butler.datastore.root,
skip_dimensions={"instrument", "detector",
"skymap", "tract", "patch"},
transfer="copy")
# Transferring governor dimensions in parallel can cause deadlocks in
# central registry. We need to transfer our exposure/visit dimensions,
# so handle those manually.
for dimension in ["exposure",
"visit",
"visit_definition",
"visit_detector_region",
"visit_system",
"visit_system_membership",
]:
for record in self.butler.registry.queryDimensionRecords(
dimension,
where="exposure in (exposure_ids)",
bind={"exposure_ids": exposure_ids},
instrument=self.instrument.getName(),
detector=self.visit.detector,
):
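# update=False: insert records missing from the central repo, but never
# modify records that already exist there.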
self.central_butler.registry.syncDimensionData(dimension, record, update=False)
self.central_butler.transfer_from(self.butler, datasets, transfer="copy", transfer_dimensions=False)

return datasets

7 changes: 5 additions & 2 deletions tests/test_middleware_interface.py
@@ -921,8 +921,11 @@ def _simulate_run(self):
self.processed_data_id = {(k if k != "exposure" else "visit"): v for k, v in self.raw_data_id.items()}
self.second_processed_data_id = {(k if k != "exposure" else "visit"): v
for k, v in self.second_data_id.items()}
# Dataset types defined for local Butler on pipeline run, but no
# guarantee this happens in central Butler.
# Dataset types defined for local Butler on pipeline run, but code
# assumes output types already exist in central repo.
butler_tests.addDatasetType(self.interface.central_butler, "calexp",
{"instrument", "visit", "detector"},
"ExposureF")
butler_tests.addDatasetType(self.interface.butler, "calexp",
{"instrument", "visit", "detector"},
"ExposureF")