Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 78 additions & 4 deletions python/activator/middleware_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import tempfile
import typing

import astropy

from lsst.resources import ResourcePath
import lsst.afw.cameraGeom
from lsst.ctrl.mpexec import SeparablePipelineExecutor
Expand All @@ -45,6 +47,9 @@

_log = logging.getLogger("lsst." + __name__)
_log.setLevel(logging.DEBUG)
# See https://developer.lsst.io/stack/logging.html#logger-trace-verbosity
_log_trace = logging.getLogger("TRACE1.lsst." + __name__)
_log_trace.setLevel(logging.CRITICAL) # Turn off by default.


def get_central_butler(central_repo: str, instrument_class: str):
Expand Down Expand Up @@ -503,14 +508,18 @@ def _export_calibs(self, export, detector_id, filter):
# TODO: we can't filter by validity range because it's not
# supported in queryDatasets yet.
calib_where = f"detector={detector_id} and physical_filter='{filter}'"
# private_sndStamp is in TAI, not UTC, but difference shouldn't matter
calib_date = datetime.datetime.fromtimestamp(self.visit.private_sndStamp, tz=datetime.timezone.utc)
# TODO: we can't use findFirst=True yet because findFirst query
# in CALIBRATION-type collection is not supported currently.
calibs = set(_filter_datasets(
self.central_butler, self.butler,
...,
collections=self.instrument.makeCalibrationCollectionName(),
instrument=self.instrument.getName(),
where=calib_where))
where=calib_where,
calib_date=calib_date,
))
if calibs:
for dataset_type, n_datasets in self._count_by_type(calibs):
_log.debug("Found %d new calib datasets of type '%s'.", n_datasets, dataset_type)
Expand Down Expand Up @@ -916,8 +925,11 @@ class _MissingDatasetError(RuntimeError):
pass


def _filter_datasets(src_repo: Butler, dest_repo: Butler,
*args, **kwargs) -> collections.abc.Iterable[lsst.daf.butler.DatasetRef]:
def _filter_datasets(src_repo: Butler,
dest_repo: Butler,
*args,
calib_date: datetime.datetime | None = None,
**kwargs) -> collections.abc.Iterable[lsst.daf.butler.DatasetRef]:
"""Identify datasets in a source repository, filtering out those already
present in a destination.

Expand All @@ -930,6 +942,10 @@ def _filter_datasets(src_repo: Butler, dest_repo: Butler,
The repository in which a dataset must be present.
dest_repo : `lsst.daf.butler.Butler`
The repository in which a dataset must not be present.
calib_date : `datetime.datetime`, optional
If provided, also filter anything other than calibs valid at
``calib_date`` and check that at least one valid calib was found.
Any ``datetime`` object must be aware.
*args, **kwargs
Parameters for describing the dataset query. They have the same
meanings as the parameters of `lsst.daf.butler.Registry.queryDatasets`.
Expand All @@ -943,7 +959,8 @@ def _filter_datasets(src_repo: Butler, dest_repo: Butler,
Raises
------
_MissingDatasetError
Raised if the query on ``src_repo`` failed to find any datasets.
Raised if the query on ``src_repo`` failed to find any datasets, or
(if ``calib_date`` is set) if none of them are currently valid.
"""
try:
known_datasets = set(dest_repo.registry.queryDatasets(*args, **kwargs))
Expand All @@ -958,6 +975,13 @@ def _filter_datasets(src_repo: Butler, dest_repo: Butler,
# Let exceptions from src_repo query raise: if it fails, that invalidates
# this operation.
src_datasets = set(src_repo.registry.queryDatasets(*args, **kwargs))
if calib_date:
src_datasets = _filter_calibs_by_date(
src_repo,
kwargs["collections"] if "collections" in kwargs else ...,
src_datasets,
calib_date,
)
if not src_datasets:
raise _MissingDatasetError(
"Source repo query with args '{}, {}' found no matches.".format(
Expand Down Expand Up @@ -1010,3 +1034,53 @@ def _remove_from_chain(butler: Butler, chain: str, old_collections: collections.
for old in set(old_collections).intersection(contents):
contents.remove(old)
butler.registry.setCollectionChain(chain, contents, flatten=False)


def _filter_calibs_by_date(butler: Butler,
                           collections: typing.Any,
                           unfiltered_calibs: collections.abc.Collection[lsst.daf.butler.DatasetRef],
                           date: datetime.datetime
                           ) -> collections.abc.Iterable[lsst.daf.butler.DatasetRef]:
    """Select the subset of calib datasets whose validity range covers a time.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The Butler in which to look up validity information.
    collections : collection expression
        The calibration collection(s), or chain(s) containing calibration
        collections, that provide the validity ranges.
    unfiltered_calibs : collection [`lsst.daf.butler.DatasetRef`]
        The candidate calibs; may be empty.
    date : `datetime.datetime`
        The time at which a calib must be valid to be kept. Must be an
        aware ``datetime``.

    Returns
    -------
    filtered_calibs : iterable [`lsst.daf.butler.DatasetRef`]
        Those members of ``unfiltered_calibs`` valid at ``date``.
    """
    # Map each candidate ref to its association in a CALIBRATION collection;
    # the association carries the validity timespan. One registry query per
    # distinct dataset type keeps the number of round trips small.
    assoc_by_ref = {}
    for one_type in {ref.datasetType for ref in unfiltered_calibs}:
        found = butler.registry.queryDatasetAssociations(
            one_type, collections, collectionTypes={CollectionType.CALIBRATION}, flattenChains=True
        )
        assoc_by_ref.update({a.ref: a for a in found})

    valid_time = astropy.time.Time(date, scale='utc')
    _log_trace.debug("Looking up calibs for %s in %s.", valid_time, collections)
    # DatasetAssociation.timespan guaranteed not None
    matched = []
    for ref in unfiltered_calibs:
        association = assoc_by_ref.get(ref)
        if association is None:
            _log_trace.debug("No calib associations for %s.", ref)
        elif association.timespan.contains(valid_time):
            matched.append(ref)
            _log_trace.debug("%s (valid over %s) matches %s.", ref, association.timespan, valid_time)
        else:
            _log_trace.debug("%s (valid over %s) does not match %s.", ref, association.timespan, valid_time)
    return matched
111 changes: 102 additions & 9 deletions tests/test_middleware_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import os.path
import unittest
import unittest.mock
import warnings

import astropy.coordinates
import astropy.units as u
Expand All @@ -43,7 +44,7 @@
from activator.config import PipelinesConfig
from activator.visit import FannedOutVisit
from activator.middleware_interface import get_central_butler, make_local_repo, MiddlewareInterface, \
_filter_datasets, _prepend_collection, _remove_from_chain, _MissingDatasetError
_filter_datasets, _prepend_collection, _remove_from_chain, _filter_calibs_by_date, _MissingDatasetError

# The short name of the instrument used in the test repo.
instname = "DECam"
Expand Down Expand Up @@ -162,7 +163,7 @@ def setUp(self):
dome=FannedOutVisit.Dome.OPEN,
duration=35.0,
totalCheckpoints=1,
private_sndStamp=1_674_516_794.0,
private_sndStamp=1424237298.7165175,
)
self.logger_name = "lsst.activator.middleware_interface"
self.interface = MiddlewareInterface(self.central_butler, self.input_data, self.next_visit,
Expand Down Expand Up @@ -213,7 +214,7 @@ def test_init(self):
self.assertEqual(self.interface.rawIngestTask.config.failFast, True)
self.assertEqual(self.interface.rawIngestTask.config.transfer, "copy")

def _check_imports(self, butler, detector, expected_shards):
def _check_imports(self, butler, detector, expected_shards, expected_date):
"""Test that the butler has the expected supporting data.
"""
self.assertEqual(butler.get('camera',
Expand Down Expand Up @@ -241,7 +242,7 @@ def _check_imports(self, butler, detector, expected_shards):
# TODO: Have to use the exact run collection, because we can't
# query by validity range.
# collections=self.umbrella)
collections="DECam/calib/20150218T000000Z")
collections=f"DECam/calib/{expected_date}")
)
self.assertTrue(
butler.exists('cpFlat', detector=detector, instrument='DECam',
Expand All @@ -250,7 +251,7 @@ def _check_imports(self, butler, detector, expected_shards):
# TODO: Have to use the exact run collection, because we can't
# query by validity range.
# collections=self.umbrella)
collections="DECam/calib/20150218T000000Z")
collections=f"DECam/calib/{expected_date}")
)

# Check that the right templates are in the chained output collection.
Expand All @@ -274,7 +275,47 @@ def test_prep_butler(self):
# TODO DM-34112: check these shards again with some plots, once I've
# determined whether ci_hits2015 actually has enough shards.
expected_shards = {157394, 157401, 157405}
self._check_imports(self.interface.butler, detector=56, expected_shards=expected_shards)
self._check_imports(self.interface.butler, detector=56,
expected_shards=expected_shards, expected_date="20150218T000000Z")

def test_prep_butler_olddate(self):
    """Test that prep_butler returns only calibs from a particular date range.
    """
    # Build the timestamp from an explicitly UTC-aware datetime instead of
    # fromisoformat("20150313T000000Z"): the compact ISO-8601 form with a
    # "Z" suffix is only parsed by fromisoformat on Python >= 3.11, and the
    # aware construction is identical in value where both work.
    self.interface.visit = dataclasses.replace(
        self.interface.visit,
        private_sndStamp=datetime.datetime(2015, 3, 13, tzinfo=datetime.timezone.utc).timestamp(),
    )
    self.interface.prep_butler()

    # These shards were identified by plotting the objects in each shard
    # on-sky and overplotting the detector corners.
    # TODO DM-34112: check these shards again with some plots, once I've
    # determined whether ci_hits2015 actually has enough shards.
    expected_shards = {157394, 157401, 157405}
    with self.assertRaises((AssertionError, lsst.daf.butler.registry.MissingCollectionError)):
        # 20150218T000000Z run should not be imported
        self._check_imports(self.interface.butler, detector=56,
                            expected_shards=expected_shards, expected_date="20150218T000000Z")
    self._check_imports(self.interface.butler, detector=56,
                        expected_shards=expected_shards, expected_date="20150313T000000Z")

# TODO: prep_butler doesn't know what kinds of calibs to expect, so can't
# tell that there are specifically, e.g., no flats. This test should pass
# as-is after DM-40245.
@unittest.expectedFailure
def test_prep_butler_novalid(self):
    """Test that prep_butler raises if no calibs are currently valid.
    """
    # A visit timestamped far in the future (2050) lies outside every calib
    # validity range in the test repo.
    # NOTE(review): datetime(2050, 1, 1) is naive, so .timestamp() depends
    # on the local timezone — harmless here, but confirm UTC was intended.
    self.interface.visit = dataclasses.replace(
        self.interface.visit,
        private_sndStamp=datetime.datetime(2050, 1, 1).timestamp(),
    )

    with warnings.catch_warnings():
        # Avoid "dubious year" warnings from using a 2050 date
        warnings.simplefilter("ignore", category=astropy.utils.exceptions.ErfaWarning)
        with self.assertRaises(_MissingDatasetError):
            self.interface.prep_butler()

def test_prep_butler_twice(self):
"""prep_butler should have the correct calibs (and not raise an
Expand All @@ -293,7 +334,8 @@ def test_prep_butler_twice(self):

second_interface.prep_butler()
expected_shards = {157394, 157401, 157405}
self._check_imports(second_interface.butler, detector=56, expected_shards=expected_shards)
self._check_imports(second_interface.butler, detector=56,
expected_shards=expected_shards, expected_date="20150218T000000Z")

# Third visit with different detector and coordinates.
# Only 5, 10, 56, 60 have valid calibs.
Expand All @@ -309,7 +351,8 @@ def test_prep_butler_twice(self):
prefix="file://")
third_interface.prep_butler()
expected_shards.update({157393, 157395})
self._check_imports(third_interface.butler, detector=5, expected_shards=expected_shards)
self._check_imports(third_interface.butler, detector=5,
expected_shards=expected_shards, expected_date="20150218T000000Z")

def test_ingest_image(self):
self.interface.prep_butler() # Ensure raw collections exist.
Expand Down Expand Up @@ -693,6 +736,56 @@ def test_remove_from_chain(self):
_remove_from_chain(butler, "_remove_base", ["_remove2", "_remove3"])
self.assertEqual(list(butler.registry.getCollectionChain("_remove_base")), ["_remove1"])

def test_filter_calibs_by_date_early(self):
    """A date during the first validity range should select only that run."""
    when = datetime.datetime(2015, 2, 26, tzinfo=datetime.timezone.utc)
    # _filter_calibs_by_date requires a collection, not merely an iterable
    biases = list(self.central_butler.registry.queryDatasets("cpBias"))
    filtered = list(_filter_calibs_by_date(self.central_butler, "DECam/calib", biases, when))
    self.assertEqual(len(filtered), 4)
    for ref in filtered:
        self.assertEqual(ref.run, "DECam/calib/20150218T000000Z")

def test_filter_calibs_by_date_late(self):
    """A date during the second validity range should select only that run."""
    when = datetime.datetime(2015, 3, 16, tzinfo=datetime.timezone.utc)
    # _filter_calibs_by_date requires a collection, not merely an iterable
    flats = list(self.central_butler.registry.queryDatasets("cpFlat"))
    filtered = list(_filter_calibs_by_date(self.central_butler, "DECam/calib", flats, when))
    self.assertEqual(len(filtered), 4)
    for ref in filtered:
        self.assertEqual(ref.run, "DECam/calib/20150313T000000Z")

def test_filter_calibs_by_date_never(self):
    """A far-future date should match no calib's validity range."""
    # _filter_calibs_by_date requires a collection, not merely an iterable
    biases = list(self.central_butler.registry.queryDatasets("cpBias"))
    with warnings.catch_warnings():
        # Avoid "dubious year" warnings from using a 2050 date
        warnings.simplefilter("ignore", category=astropy.utils.exceptions.ErfaWarning)
        filtered = list(_filter_calibs_by_date(
            self.central_butler, "DECam/calib", biases,
            datetime.datetime(2050, 1, 1, tzinfo=datetime.timezone.utc)
        ))
    self.assertEqual(len(filtered), 0)

def test_filter_calibs_by_date_unbounded(self):
    """Calibs with unbounded validity should always pass the filter."""
    # _filter_calibs_by_date requires a collection, not merely an iterable
    all_calibs = set(self.central_butler.registry.queryDatasets(["camera", "crosstalk"]))
    when = datetime.datetime(2015, 3, 15, tzinfo=datetime.timezone.utc)
    surviving = set(_filter_calibs_by_date(self.central_butler, "DECam/calib", all_calibs, when))
    self.assertEqual(surviving, all_calibs)

def test_filter_calibs_by_date_empty(self):
    """Filtering an empty input should return an empty result."""
    when = datetime.datetime(2015, 3, 15, tzinfo=datetime.timezone.utc)
    surviving = set(_filter_calibs_by_date(self.central_butler, "DECam/calib", [], when))
    self.assertEqual(len(surviving), 0)


class MiddlewareInterfaceWriteableTest(unittest.TestCase):
"""Test the MiddlewareInterface class with faked data.
Expand Down Expand Up @@ -780,7 +873,7 @@ def setUp(self):
dome=FannedOutVisit.Dome.OPEN,
duration=35.0,
totalCheckpoints=1,
private_sndStamp=1_674_516_794.0,
private_sndStamp=1424237298.716517500,
)
self.logger_name = "lsst.activator.middleware_interface"

Expand Down