Merge pull request #103 from lsst/tickets/DM-16482
DM-16482: Implement support for ExposureRange unit
andy-slac committed Nov 16, 2018
2 parents 3e8a0f8 + 204b45c commit 092a88e
Showing 2 changed files with 201 additions and 12 deletions.
54 changes: 42 additions & 12 deletions python/lsst/daf/butler/registries/sqlPreFlight.py
@@ -23,7 +23,7 @@

import itertools
import logging
from sqlalchemy.sql import select, and_, functions, text, literal, case
from sqlalchemy.sql import select, and_, functions, text, literal, case, between

from lsst.sphgeom import Region
from lsst.sphgeom.relationship import DISJOINT
@@ -211,7 +211,7 @@ def selectDataUnits(self, originInfo, expression, neededDatasetTypes, futureData
unitLinkColumns = {}
for unitName in allUnitNames:
dataUnit = self._dataUnits[unitName]
if self._schema.tables[unitName] is not None:
if self._schema.tables.get(unitName) is not None:
# take link column names; usually there is one
for link in dataUnit.link:
unitLinkColumns[link] = len(selectColumns)
@@ -230,7 +230,7 @@ def selectDataUnits(self, originInfo, expression, neededDatasetTypes, futureData
# joins for all unit tables
fromJoin = None
for dataUnit in _unitsTopologicalSort(allDataUnits.values()):
if self._schema.tables[dataUnit.name] is None:
if self._schema.tables.get(dataUnit.name) is None:
continue
_LOG.debug("add dataUnit: %s", dataUnit.name)
fromJoin = self._joinOnForeignKey(fromJoin, dataUnit, dataUnit.dependencies)
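The switch from `tables[name]` to `tables.get(name)` in this hunk matters because not every DataUnit has a backing table. A minimal sketch of the difference, using a plain dict as a stand-in for `self._schema.tables` (the real object's behavior on missing keys is assumed to match a mapping):

```python
# Toy stand-in for self._schema.tables; ExposureRange has no table of its own.
tables = {"Exposure": "<Table Exposure>", "Detector": "<Table Detector>"}

unitName = "ExposureRange"
if tables.get(unitName) is not None:  # .get() returns None for a missing key
    print("joining", unitName)

# The old spelling would raise instead of skipping the unit:
# tables[unitName]  -> KeyError: 'ExposureRange'
```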
@@ -248,14 +248,17 @@ def selectDataUnits(self, originInfo, expression, neededDatasetTypes, futureData
joinedRegionTables = set()
regionColumns = {}
for dataUnitJoin in dataUnitJoins:
_LOG.debug("processing dataUnitJoin: %s", dataUnitJoin.name)
# Some `DataUnitJoin`s have an associated region (i.e. they are spatial);
# in that case they should not be joined separately in the region lookup.
if dataUnitJoin.spatial:
_LOG.debug("%s is spatial, skipping", dataUnitJoin.name)
continue

# TODO: we do not yet know how to handle MultiInstrumentExposureJoin;
# skip it for now.
if dataUnitJoin.lhs == dataUnitJoin.rhs:
_LOG.debug("%s is unsupported, skipping", dataUnitJoin.name)
continue

# Look at each side of the DataUnitJoin and join it with
@@ -270,7 +273,12 @@ def selectDataUnits(self, originInfo, expression, neededDatasetTypes, futureData
units.append(dataUnitName)
dataUnit = self._dataUnits[dataUnitName]
units += [d.name for d in dataUnit.requiredDependencies if d.spatial]
regionHolder = self._dataUnits.getRegionHolder(*units)
try:
regionHolder = self._dataUnits.getRegionHolder(*units)
except KeyError:
# there is no region for these units, so skip this join
_LOG.debug("Units %s are not spatial, skipping", units)
break
if len(connection) > 1:
# if one of the joins is with Visit/Detector then also bring in the
# VisitDetectorRegion table and join it with the units
@@ -294,7 +302,8 @@ def selectDataUnits(self, originInfo, expression, neededDatasetTypes, futureData
regionColumn = self._schema.tables[regionHolder.name].c.region
selectColumns.append(regionColumn)

fromJoin = self._joinOnForeignKey(fromJoin, dataUnitJoin, regionHolders)
if regionHolders:
fromJoin = self._joinOnForeignKey(fromJoin, dataUnitJoin, regionHolders)

# join with input datasets to restrict to existing inputs
dsIdColumns = {}
@@ -319,9 +328,23 @@ def selectDataUnits(self, originInfo, expression, neededDatasetTypes, futureData
joinOn = []
for unitName in dsType.dataUnits:
dataUnit = allDataUnits[unitName]
for link in dataUnit.link:
_LOG.debug(" joining on link: %s", link)
joinOn.append(subquery.c[link] == self._schema.tables[dataUnit.name].c[link])
if unitName == "ExposureRange":
# Very special handling of ExposureRange.
# TODO: try to generalize this in some way, maybe using
# SQL from ExposureRangeJoin.
_LOG.debug(" joining on unit: %s", unitName)
exposureTable = self._schema.tables["Exposure"]
joinOn.append(between(exposureTable.c.datetime_begin,
subquery.c.valid_first,
subquery.c.valid_last))
unitLinkColumns[dsType.name + ".valid_first"] = len(selectColumns)
selectColumns.append(subquery.c.valid_first)
unitLinkColumns[dsType.name + ".valid_last"] = len(selectColumns)
selectColumns.append(subquery.c.valid_last)
else:
for link in dataUnit.link:
_LOG.debug(" joining on link: %s", link)
joinOn.append(subquery.c[link] == self._schema.tables[dataUnit.name].c[link])
fromJoin = fromJoin.join(subquery, and_(*joinOn), isouter=isOutput)

# remember dataset_id column index for this dataset
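To see what the `between()` condition built in the hunk above produces, here is a small self-contained SQLAlchemy sketch. The table and column names follow the diff; everything else (toy schema, SQLAlchemy 1.x-style `select([...])`) is illustrative, not the module's actual code:

```python
from sqlalchemy import Column, DateTime, Integer, MetaData, Table
from sqlalchemy.sql import between, select

metadata = MetaData()
exposure = Table("Exposure", metadata,
                 Column("exposure", Integer, primary_key=True),
                 Column("datetime_begin", DateTime))
# stand-in for the per-dataset-type subquery carrying the validity range
dataset = Table("Dataset", metadata,
                Column("dataset_id", Integer, primary_key=True),
                Column("valid_first", DateTime),
                Column("valid_last", DateTime))

# A calibration dataset matches an exposure when the exposure's start time
# falls inside the dataset's [valid_first, valid_last] validity window.
query = select([exposure.c.exposure, dataset.c.dataset_id]).select_from(
    exposure.join(dataset, between(exposure.c.datetime_begin,
                                   dataset.c.valid_first,
                                   dataset.c.valid_last)))
# str(query) renders roughly: ... JOIN "Dataset" ON "Exposure".datetime_begin
#                             BETWEEN "Dataset".valid_first AND "Dataset".valid_last
```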
@@ -545,12 +568,19 @@ def _convertResultRows(self, rowIter, unitLinkColumns, regionColumns, dsIdColumn
# for each dataset get ids DataRef
datasetRefs = {}
for dsType, col in dsIdColumns.items():
linkNames = set()
linkNames = {} # maps full link name in unitLinkColumns to dataId key
for unitName in dsType.dataUnits:
dataUnit = self._dataUnits[unitName]
if self._schema.tables[dataUnit.name] is not None:
linkNames.update(dataUnit.link)
dsDataId = dict((link, row[unitLinkColumns[link]]) for link in linkNames)
if unitName == "ExposureRange":
# special case of ExposureRange: its columns come from the
# Dataset table instead of a DataUnit table
linkNames[dsType.name + ".valid_first"] = "valid_first"
linkNames[dsType.name + ".valid_last"] = "valid_last"
else:
if self._schema.tables.get(dataUnit.name) is not None:
for link in dataUnit.link:
linkNames[link] = link
dsDataId = dict((val, row[unitLinkColumns[key]]) for key, val in linkNames.items())
dsId = None if col is None else row[col]
datasetRefs[dsType] = DatasetRef(dsType, dsDataId, dsId)
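A toy illustration of the remapping `linkNames` now performs. The row tuple and index values here are made up; in the real code `unitLinkColumns` maps a column label to its position in the result row, and for ExposureRange the labels are `"<datasetType>.valid_first"`/`".valid_last"` while the dataId wants plain `"valid_first"`/`"valid_last"`:

```python
row = ("DummyCam", 1, "2018-11-16T00:00:00", "2018-11-16T01:00:00")
unitLinkColumns = {"instrument": 0, "detector": 1,
                   "bias.valid_first": 2, "bias.valid_last": 3}

# keys are the labels used when selecting columns; values are dataId keys
linkNames = {"instrument": "instrument", "detector": "detector",
             "bias.valid_first": "valid_first",
             "bias.valid_last": "valid_last"}
dsDataId = dict((val, row[unitLinkColumns[key]]) for key, val in linkNames.items())
assert dsDataId == {"instrument": "DummyCam", "detector": 1,
                    "valid_first": "2018-11-16T00:00:00",
                    "valid_last": "2018-11-16T01:00:00"}
```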

159 changes: 159 additions & 0 deletions tests/test_sqlPreFlight.py
@@ -19,6 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from datetime import datetime, timedelta
import os
import unittest

@@ -213,6 +214,164 @@ def testPreFlightInstrumentUnits(self):
self.assertCountEqual(set(row.dataId["visit"] for row in rows), (11,))
self.assertCountEqual(set(row.dataId["detector"] for row in rows), (1, 2, 3))

def testPreFlightExposureRange(self):
"""Test involving only ExposureRange unit"""
registry = self.registry

# need a bunch of units and datasets for the test
registry.addDataUnitEntry("Instrument", dict(instrument="DummyCam"))
registry.addDataUnitEntry("PhysicalFilter", dict(instrument="DummyCam",
physical_filter="dummy_r",
abstract_filter="r"))
for detector in (1, 2, 3, 4, 5):
registry.addDataUnitEntry("Detector", dict(instrument="DummyCam", detector=detector))

# make a few visits/exposures
now = datetime.now()
timestamps = []  # list of (start, end) times for each exposure
for visit in (10, 11, 20):
registry.addDataUnitEntry("Visit",
dict(instrument="DummyCam", visit=visit, physical_filter="dummy_r"))
visit_start = now + timedelta(seconds=visit*45)
for exposure in (0, 1):
start = visit_start + timedelta(seconds=15*exposure)
end = start + timedelta(seconds=15)
registry.addDataUnitEntry("Exposure", dict(instrument="DummyCam",
exposure=visit*10+exposure,
visit=visit,
physical_filter="dummy_r",
datetime_begin=start,
datetime_end=end))
timestamps += [(start, end)]
self.assertEqual(len(timestamps), 6)

# dataset types
collection = "test"
run = registry.makeRun(collection=collection)
storageClass = StorageClass("testExposureRange")
registry.storageClasses.registerStorageClass(storageClass)
rawType = DatasetType(name="RAW", dataUnits=("Instrument", "Detector", "Exposure"),
storageClass=storageClass)
registry.registerDatasetType(rawType)
biasType = DatasetType(name="bias", dataUnits=("Instrument", "Detector", "ExposureRange"),
storageClass=storageClass)
registry.registerDatasetType(biasType)
flatType = DatasetType(name="flat",
dataUnits=("Instrument", "Detector", "PhysicalFilter", "ExposureRange"),
storageClass=storageClass)
registry.registerDatasetType(flatType)
calexpType = DatasetType(name="CALEXP", dataUnits=("Instrument", "Visit", "Detector"),
storageClass=storageClass)
registry.registerDatasetType(calexpType)

# add pre-existing raw datasets
for visit in (10, 11, 20):
for exposure in (0, 1):
for detector in (1, 2, 3, 4, 5):
dataId = dict(instrument="DummyCam", exposure=visit*10+exposure, detector=detector)
registry.addDataset(rawType, dataId=dataId, run=run)

# add a few bias datasets
for detector in (1, 2, 3, 4, 5):
# from before the first exposure to the end of the second exposure
dataId = dict(instrument="DummyCam", detector=detector,
valid_first=now-timedelta(seconds=3600),
valid_last=timestamps[1][1])
registry.addDataset(biasType, dataId=dataId, run=run)
# from the start of the third exposure to the end of the last exposure
dataId = dict(instrument="DummyCam", detector=detector,
valid_first=timestamps[2][0],
valid_last=timestamps[-1][1])
registry.addDataset(biasType, dataId=dataId, run=run)

# add a few flat datasets, only for a subset of detectors and exposures
for detector in (1, 2, 3):
# third and fourth exposures
dataId = dict(instrument="DummyCam", detector=detector,
physical_filter="dummy_r",
valid_first=timestamps[2][0],
valid_last=timestamps[3][1])
registry.addDataset(flatType, dataId=dataId, run=run)
# fifth exposure only
dataId = dict(instrument="DummyCam", detector=detector,
physical_filter="dummy_r",
valid_first=timestamps[4][0],
valid_last=timestamps[5][0]-timedelta(seconds=1))
registry.addDataset(flatType, dataId=dataId, run=run)
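As a standalone sanity check of the row counts asserted below, this sketch reproduces the test's timing layout and computes which exposures each validity window covers; `covered` is a hypothetical helper written for this note, not part of the registry API. The biases jointly cover all six exposures, while the flats cover only exposures 2-4:

```python
from datetime import datetime, timedelta

now = datetime.now()
timestamps = []
for visit in (10, 11, 20):
    visit_start = now + timedelta(seconds=visit * 45)
    for exposure in (0, 1):
        start = visit_start + timedelta(seconds=15 * exposure)
        timestamps.append((start, start + timedelta(seconds=15)))

def covered(valid_first, valid_last):
    # exposures whose datetime_begin falls inside [valid_first, valid_last]
    return [i for i, (start, _) in enumerate(timestamps)
            if valid_first <= start <= valid_last]

assert covered(now - timedelta(seconds=3600), timestamps[1][1]) == [0, 1]  # bias 1
assert covered(timestamps[2][0], timestamps[-1][1]) == [2, 3, 4, 5]        # bias 2
assert covered(timestamps[2][0], timestamps[3][1]) == [2, 3]               # flat 1
assert covered(timestamps[4][0],
               timestamps[5][0] - timedelta(seconds=1)) == [4]             # flat 2
# biases cover all 6 exposures -> 6*5 rows; flats cover 3 -> 3*3 rows below
```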

# without flat/bias
originInfo = DatasetOriginInfoDef(defaultInputs=[collection], defaultOutput=collection)
rows = self.preFlight.selectDataUnits(originInfo=originInfo,
expression="",
neededDatasetTypes=[rawType],
futureDatasetTypes=[calexpType])
rows = list(rows)
self.assertEqual(len(rows), 6*5) # 6 exposures times 5 detectors
for row in rows:
self.assertCountEqual(row.dataId.keys(), ("instrument", "detector", "exposure", "visit"))
self.assertCountEqual(row.datasetRefs.keys(), (rawType, calexpType))

# use bias
originInfo = DatasetOriginInfoDef(defaultInputs=[collection], defaultOutput=collection)
rows = self.preFlight.selectDataUnits(originInfo=originInfo,
expression="",
neededDatasetTypes=[rawType, biasType],
futureDatasetTypes=[calexpType])
rows = list(rows)
self.assertEqual(len(rows), 6*5) # 6 exposures times 5 detectors
for row in rows:
self.assertCountEqual(row.dataId.keys(),
("instrument", "detector", "exposure", "visit",
"bias.valid_first", "bias.valid_last"))
self.assertCountEqual(row.datasetRefs.keys(), (rawType, biasType, calexpType))

# use flat
rows = self.preFlight.selectDataUnits(originInfo=originInfo,
expression="",
neededDatasetTypes=[rawType, flatType],
futureDatasetTypes=[calexpType])
rows = list(rows)
self.assertEqual(len(rows), 3*3) # 3 exposures times 3 detectors
for row in rows:
self.assertCountEqual(row.dataId.keys(),
("instrument", "detector", "exposure", "visit", "physical_filter",
"flat.valid_first", "flat.valid_last"))
self.assertCountEqual(row.datasetRefs.keys(), (rawType, flatType, calexpType))

# use both bias and flat, plus expression
rows = self.preFlight.selectDataUnits(originInfo=originInfo,
expression="Detector.detector IN (1, 3)",
neededDatasetTypes=[rawType, flatType, biasType],
futureDatasetTypes=[calexpType])
rows = list(rows)
self.assertEqual(len(rows), 3*2) # 3 exposures times 2 detectors
for row in rows:
self.assertCountEqual(row.dataId.keys(),
("instrument", "detector", "exposure", "visit", "physical_filter",
"bias.valid_first", "bias.valid_last",
"flat.valid_first", "flat.valid_last"))
self.assertCountEqual(row.datasetRefs.keys(), (rawType, flatType, biasType, calexpType))

# select a single exposure (the third) and a single detector, then check datasetRefs
rows = self.preFlight.selectDataUnits(originInfo=originInfo,
expression="Exposure.exposure = 110 AND Detector.detector = 1",
neededDatasetTypes=[rawType, flatType, biasType],
futureDatasetTypes=[calexpType])
rows = list(rows)
self.assertEqual(len(rows), 1)
row = rows[0]
self.assertEqual(row.datasetRefs[flatType].dataId,
dict(instrument="DummyCam",
detector=1,
physical_filter="dummy_r",
valid_first=timestamps[2][0],
valid_last=timestamps[3][1]))
self.assertEqual(row.datasetRefs[biasType].dataId,
dict(instrument="DummyCam",
detector=1,
valid_first=timestamps[2][0],
valid_last=timestamps[-1][1]))

def testPreFlightSkyMapUnits(self):
"""Test involving only SkyMap units, no joins to Instrument"""
registry = self.registry
