diff --git a/python/activator/middleware_interface.py b/python/activator/middleware_interface.py index bca723d7..1603bc18 100644 --- a/python/activator/middleware_interface.py +++ b/python/activator/middleware_interface.py @@ -31,6 +31,8 @@ import tempfile import typing +import astropy + from lsst.resources import ResourcePath import lsst.afw.cameraGeom from lsst.ctrl.mpexec import SeparablePipelineExecutor @@ -45,6 +47,9 @@ _log = logging.getLogger("lsst." + __name__) _log.setLevel(logging.DEBUG) +# See https://developer.lsst.io/stack/logging.html#logger-trace-verbosity +_log_trace = logging.getLogger("TRACE1.lsst." + __name__) +_log_trace.setLevel(logging.CRITICAL) # Turn off by default. def get_central_butler(central_repo: str, instrument_class: str): @@ -503,6 +508,8 @@ def _export_calibs(self, export, detector_id, filter): # TODO: we can't filter by validity range because it's not # supported in queryDatasets yet. calib_where = f"detector={detector_id} and physical_filter='{filter}'" + # private_sndStamp is in TAI, not UTC, but difference shouldn't matter + calib_date = datetime.datetime.fromtimestamp(self.visit.private_sndStamp, tz=datetime.timezone.utc) # TODO: we can't use findFirst=True yet because findFirst query # in CALIBRATION-type collection is not supported currently. 
calibs = set(_filter_datasets( @@ -510,7 +517,9 @@ def _export_calibs(self, export, detector_id, filter): ..., collections=self.instrument.makeCalibrationCollectionName(), instrument=self.instrument.getName(), - where=calib_where)) + where=calib_where, + calib_date=calib_date, + )) if calibs: for dataset_type, n_datasets in self._count_by_type(calibs): _log.debug("Found %d new calib datasets of type '%s'.", n_datasets, dataset_type) @@ -916,8 +925,11 @@ class _MissingDatasetError(RuntimeError): pass -def _filter_datasets(src_repo: Butler, dest_repo: Butler, - *args, **kwargs) -> collections.abc.Iterable[lsst.daf.butler.DatasetRef]: +def _filter_datasets(src_repo: Butler, + dest_repo: Butler, + *args, + calib_date: datetime.datetime | None = None, + **kwargs) -> collections.abc.Iterable[lsst.daf.butler.DatasetRef]: """Identify datasets in a source repository, filtering out those already present in a destination. @@ -930,6 +942,10 @@ def _filter_datasets(src_repo: Butler, dest_repo: Butler, The repository in which a dataset must be present. dest_repo : `lsst.daf.butler.Butler` The repository in which a dataset must not be present. + calib_date : `datetime.datetime`, optional + If provided, also filter out anything other than calibs valid at + ``calib_date`` and check that at least one valid calib was found. + Any ``datetime`` object must be aware. *args, **kwargs Parameters for describing the dataset query. They have the same meanings as the parameters of `lsst.daf.butler.Registry.queryDatasets`. @@ -943,7 +959,8 @@ def _filter_datasets(src_repo: Butler, dest_repo: Butler, Raises ------ _MissingDatasetError - Raised if the query on ``src_repo`` failed to find any datasets. + Raised if the query on ``src_repo`` failed to find any datasets, or + (if ``calib_date`` is set) if none of them are valid at ``calib_date``. 
""" try: known_datasets = set(dest_repo.registry.queryDatasets(*args, **kwargs)) @@ -958,6 +975,13 @@ def _filter_datasets(src_repo: Butler, dest_repo: Butler, # Let exceptions from src_repo query raise: if it fails, that invalidates # this operation. src_datasets = set(src_repo.registry.queryDatasets(*args, **kwargs)) + if calib_date: + src_datasets = _filter_calibs_by_date( + src_repo, + kwargs["collections"] if "collections" in kwargs else ..., + src_datasets, + calib_date, + ) if not src_datasets: raise _MissingDatasetError( "Source repo query with args '{}, {}' found no matches.".format( @@ -1010,3 +1034,53 @@ def _remove_from_chain(butler: Butler, chain: str, old_collections: collections. for old in set(old_collections).intersection(contents): contents.remove(old) butler.registry.setCollectionChain(chain, contents, flatten=False) + + +def _filter_calibs_by_date(butler: Butler, + collections: typing.Any, + unfiltered_calibs: collections.abc.Collection[lsst.daf.butler.DatasetRef], + date: datetime.datetime + ) -> collections.abc.Iterable[lsst.daf.butler.DatasetRef]: + """Trim a set of calib datasets to those that are valid at a particular time. + + Parameters + ---------- + butler : `lsst.daf.butler.Butler` + The Butler to query for validity data. + collections : collection expression + The calibration collection(s), or chain(s) containing calibration + collections, to query for validity data. + unfiltered_calibs : collection [`lsst.daf.butler.DatasetRef`] + The calibs to be filtered by validity. May be empty. + date : `datetime.datetime` + The time at which the calibs must be valid. Must be an + aware ``datetime``. + + Returns + ------- + filtered_calibs : iterable [`lsst.daf.butler.DatasetRef`] + The subset of ``unfiltered_calibs`` that is valid on ``date``. 
+ """ + dataset_types = {ref.datasetType for ref in unfiltered_calibs} + associations = {} + for dataset_type in dataset_types: + associations.update( + (a.ref, a) for a in butler.registry.queryDatasetAssociations( + dataset_type, collections, collectionTypes={CollectionType.CALIBRATION}, flattenChains=True + ) + ) + + t = astropy.time.Time(date, scale='utc') + _log_trace.debug("Looking up calibs for %s in %s.", t, collections) + # DatasetAssociation.timespan guaranteed not None + filtered_calibs = [] + for ref in unfiltered_calibs: + if ref in associations: + if associations[ref].timespan.contains(t): + filtered_calibs.append(ref) + _log_trace.debug("%s (valid over %s) matches %s.", ref, associations[ref].timespan, t) + else: + _log_trace.debug("%s (valid over %s) does not match %s.", ref, associations[ref].timespan, t) + else: + _log_trace.debug("No calib associations for %s.", ref) + return filtered_calibs diff --git a/tests/test_middleware_interface.py b/tests/test_middleware_interface.py index aca0b7ab..b8a3b0d4 100644 --- a/tests/test_middleware_interface.py +++ b/tests/test_middleware_interface.py @@ -26,6 +26,7 @@ import os.path import unittest import unittest.mock +import warnings import astropy.coordinates import astropy.units as u @@ -43,7 +44,7 @@ from activator.config import PipelinesConfig from activator.visit import FannedOutVisit from activator.middleware_interface import get_central_butler, make_local_repo, MiddlewareInterface, \ - _filter_datasets, _prepend_collection, _remove_from_chain, _MissingDatasetError + _filter_datasets, _prepend_collection, _remove_from_chain, _filter_calibs_by_date, _MissingDatasetError # The short name of the instrument used in the test repo. 
instname = "DECam" @@ -162,7 +163,7 @@ def setUp(self): dome=FannedOutVisit.Dome.OPEN, duration=35.0, totalCheckpoints=1, - private_sndStamp=1_674_516_794.0, + private_sndStamp=1424237298.7165175, ) self.logger_name = "lsst.activator.middleware_interface" self.interface = MiddlewareInterface(self.central_butler, self.input_data, self.next_visit, @@ -213,7 +214,7 @@ def test_init(self): self.assertEqual(self.interface.rawIngestTask.config.failFast, True) self.assertEqual(self.interface.rawIngestTask.config.transfer, "copy") - def _check_imports(self, butler, detector, expected_shards): + def _check_imports(self, butler, detector, expected_shards, expected_date): """Test that the butler has the expected supporting data. """ self.assertEqual(butler.get('camera', @@ -241,7 +242,7 @@ def _check_imports(self, butler, detector, expected_shards): # TODO: Have to use the exact run collection, because we can't # query by validity range. # collections=self.umbrella) - collections="DECam/calib/20150218T000000Z") + collections=f"DECam/calib/{expected_date}") ) self.assertTrue( butler.exists('cpFlat', detector=detector, instrument='DECam', @@ -250,7 +251,7 @@ def _check_imports(self, butler, detector, expected_shards): # TODO: Have to use the exact run collection, because we can't # query by validity range. # collections=self.umbrella) - collections="DECam/calib/20150218T000000Z") + collections=f"DECam/calib/{expected_date}") ) # Check that the right templates are in the chained output collection. @@ -274,7 +275,47 @@ def test_prep_butler(self): # TODO DM-34112: check these shards again with some plots, once I've # determined whether ci_hits2015 actually has enough shards. 
expected_shards = {157394, 157401, 157405} - self._check_imports(self.interface.butler, detector=56, expected_shards=expected_shards) + self._check_imports(self.interface.butler, detector=56, + expected_shards=expected_shards, expected_date="20150218T000000Z") + + def test_prep_butler_olddate(self): + """Test that prep_butler returns only calibs from a particular date range. + """ + self.interface.visit = dataclasses.replace( + self.interface.visit, + private_sndStamp=datetime.datetime.fromisoformat("20150313T000000Z").timestamp(), + ) + self.interface.prep_butler() + + # These shards were identified by plotting the objects in each shard + # on-sky and overplotting the detector corners. + # TODO DM-34112: check these shards again with some plots, once I've + # determined whether ci_hits2015 actually has enough shards. + expected_shards = {157394, 157401, 157405} + with self.assertRaises((AssertionError, lsst.daf.butler.registry.MissingCollectionError)): + # 20150218T000000Z run should not be imported + self._check_imports(self.interface.butler, detector=56, + expected_shards=expected_shards, expected_date="20150218T000000Z") + self._check_imports(self.interface.butler, detector=56, + expected_shards=expected_shards, expected_date="20150313T000000Z") + + # TODO: prep_butler doesn't know what kinds of calibs to expect, so can't + # tell that there are specifically, e.g., no flats. This test should pass + # as-is after DM-40245. + @unittest.expectedFailure + def test_prep_butler_novalid(self): + """Test that prep_butler raises if no calibs are currently valid. 
+ """ + self.interface.visit = dataclasses.replace( + self.interface.visit, + private_sndStamp=datetime.datetime(2050, 1, 1).timestamp(), + ) + + with warnings.catch_warnings(): + # Avoid "dubious year" warnings from using a 2050 date + warnings.simplefilter("ignore", category=astropy.utils.exceptions.ErfaWarning) + with self.assertRaises(_MissingDatasetError): + self.interface.prep_butler() def test_prep_butler_twice(self): """prep_butler should have the correct calibs (and not raise an @@ -293,7 +334,8 @@ def test_prep_butler_twice(self): second_interface.prep_butler() expected_shards = {157394, 157401, 157405} - self._check_imports(second_interface.butler, detector=56, expected_shards=expected_shards) + self._check_imports(second_interface.butler, detector=56, + expected_shards=expected_shards, expected_date="20150218T000000Z") # Third visit with different detector and coordinates. # Only 5, 10, 56, 60 have valid calibs. @@ -309,7 +351,8 @@ def test_prep_butler_twice(self): prefix="file://") third_interface.prep_butler() expected_shards.update({157393, 157395}) - self._check_imports(third_interface.butler, detector=5, expected_shards=expected_shards) + self._check_imports(third_interface.butler, detector=5, + expected_shards=expected_shards, expected_date="20150218T000000Z") def test_ingest_image(self): self.interface.prep_butler() # Ensure raw collections exist. 
@@ -693,6 +736,56 @@ def test_remove_from_chain(self): _remove_from_chain(butler, "_remove_base", ["_remove2", "_remove3"]) self.assertEqual(list(butler.registry.getCollectionChain("_remove_base")), ["_remove1"]) + def test_filter_calibs_by_date_early(self): + # _filter_calibs_by_date requires a collection, not merely an iterable + all_calibs = list(self.central_butler.registry.queryDatasets("cpBias")) + early_calibs = list(_filter_calibs_by_date( + self.central_butler, "DECam/calib", all_calibs, + datetime.datetime(2015, 2, 26, tzinfo=datetime.timezone.utc) + )) + self.assertEqual(len(early_calibs), 4) + for calib in early_calibs: + self.assertEqual(calib.run, "DECam/calib/20150218T000000Z") + + def test_filter_calibs_by_date_late(self): + # _filter_calibs_by_date requires a collection, not merely an iterable + all_calibs = list(self.central_butler.registry.queryDatasets("cpFlat")) + late_calibs = list(_filter_calibs_by_date( + self.central_butler, "DECam/calib", all_calibs, + datetime.datetime(2015, 3, 16, tzinfo=datetime.timezone.utc) + )) + self.assertEqual(len(late_calibs), 4) + for calib in late_calibs: + self.assertEqual(calib.run, "DECam/calib/20150313T000000Z") + + def test_filter_calibs_by_date_never(self): + # _filter_calibs_by_date requires a collection, not merely an iterable + all_calibs = list(self.central_butler.registry.queryDatasets("cpBias")) + with warnings.catch_warnings(): + # Avoid "dubious year" warnings from using a 2050 date + warnings.simplefilter("ignore", category=astropy.utils.exceptions.ErfaWarning) + future_calibs = list(_filter_calibs_by_date( + self.central_butler, "DECam/calib", all_calibs, + datetime.datetime(2050, 1, 1, tzinfo=datetime.timezone.utc) + )) + self.assertEqual(len(future_calibs), 0) + + def test_filter_calibs_by_date_unbounded(self): + # _filter_calibs_by_date requires a collection, not merely an iterable + all_calibs = set(self.central_butler.registry.queryDatasets(["camera", "crosstalk"])) + valid_calibs = 
set(_filter_calibs_by_date( + self.central_butler, "DECam/calib", all_calibs, + datetime.datetime(2015, 3, 15, tzinfo=datetime.timezone.utc) + )) + self.assertEqual(valid_calibs, all_calibs) + + def test_filter_calibs_by_date_empty(self): + valid_calibs = set(_filter_calibs_by_date( + self.central_butler, "DECam/calib", [], + datetime.datetime(2015, 3, 15, tzinfo=datetime.timezone.utc) + )) + self.assertEqual(len(valid_calibs), 0) + class MiddlewareInterfaceWriteableTest(unittest.TestCase): """Test the MiddlewareInterface class with faked data. @@ -780,7 +873,7 @@ def setUp(self): dome=FannedOutVisit.Dome.OPEN, duration=35.0, totalCheckpoints=1, - private_sndStamp=1_674_516_794.0, + private_sndStamp=1424237298.716517500, ) self.logger_name = "lsst.activator.middleware_interface"