Skip to content

Commit

Permalink
Implement Registry.getRegion.
Browse files Browse the repository at this point in the history
  • Loading branch information
Pim Schellart authored and Pim Schellart committed Aug 6, 2018
1 parent e316cfc commit 0966e7d
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 44 deletions.
15 changes: 1 addition & 14 deletions python/lsst/daf/butler/butler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,13 @@ def put(self, obj, datasetRefOrType, dataId=None, producer=None):
----------
obj : `object`
The dataset.
<<<<<<< HEAD
datasetRefOrType : `DatasetRef`, `DatasetType` instance or `str`
When `DatasetRef` the `dataId` should be `None`.
Otherwise the `DatasetType` or name thereof.
dataId : `dict`, optional
An identifier with `DataUnit` names and values.
When `None` a `DatasetRef` should be supplied as the second
argument.
=======
datasetType : `DatasetType` instance or `str`
The `DatasetType`.
dataId : `dict`
A `dict` of `DataUnit` link name, value pairs that label the
`DatasetRef` within a Collection.
>>>>>>> Consistent docstring for DataId.
producer : `Quantum`, optional
The producer.
Expand Down Expand Up @@ -304,15 +296,10 @@ def get(self, datasetRefOrType, dataId=None):
When `DatasetRef` the `dataId` should be `None`.
Otherwise the `DatasetType` or name thereof.
dataId : `dict`
<<<<<<< HEAD
A `dict` of `DataUnit` name, value pairs that label the `DatasetRef`
A `dict` of `DataUnit` link name, value pairs that label the `DatasetRef`
within a Collection.
When `None` a `DatasetRef` should be supplied as the second
argument.
=======
A `dict` of `DataUnit` link name, value pairs that label the
`DatasetRef` within a Collection.
>>>>>>> Consistent docstring for DataId.
Returns
-------
Expand Down
68 changes: 61 additions & 7 deletions python/lsst/daf/butler/core/dataUnit.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class DataUnitJoin:
materialized as views (and thus are also not present
in `Registry._schema._metadata`).
"""

def __init__(self, name, lhs=None, rhs=None, summarizes=None, isView=None, table=None):
self._name = name
self._lhs = lhs
Expand Down Expand Up @@ -243,8 +244,8 @@ class DataUnitRegion:
----------
name : `str`
Name of this `DataUnitRegion`, same as the name of the table.
relates : `tuple` of `str`
Names of the DataUnits in this relationship.
relates : `tuple` of `DataUnit`
The DataUnits in this relationship.
table : `sqlalchemy.Table`, optional
The table to be used for queries.
spatial : `bool`, optional
Expand Down Expand Up @@ -275,6 +276,22 @@ def table(self):
"""
return self._table

@property
def primaryKey(self):
    """Full set of primary-key column names (`frozenset` of `str`).

    The union of the primary keys of all the related DataUnits.
    (Previous docstring said "tuple", but the accumulator below is a
    `frozenset` and that is what is returned.)
    """
    keys = frozenset()
    for dataUnit in self.relates:
        # Each DataUnit contributes its own primary-key column names.
        keys |= dataUnit.primaryKey
    return keys

@property
def primaryKeyColumns(self):
    """Dictionary keyed on ``primaryKey`` names with `sqlalchemy.Column`
    entries into this `DataUnitRegion` primary table as values (`dict`).
    """
    columns = self.table.columns
    mapping = {}
    for name in self.primaryKey:
        mapping[name] = columns[name]
    return mapping

@property
def regionColumn(self):
"""Table column with encoded region data, ``None`` if table has no
Expand All @@ -298,13 +315,16 @@ class DataUnitRegistry:
Entries in this `dict`-like object represent `DataUnit` instances,
keyed on `DataUnit` names.
"""

def __init__(self):
    """Create an empty registry; contents are filled in by `fromConfig`."""
    # Topologically sorted DataUnit names (set during initialization).
    self._dataUnitNames = None
    # DataUnit name -> DataUnit instance.
    self._dataUnits = dict()
    # frozenset of spatial DataUnit names -> region holder
    # (DataUnit or DataUnitRegion).
    self._dataUnitRegions = dict()
    self.links = dict()
    self.constraints = list()
    self.joins = dict()
    # Link column name -> DataUnit that defines that link.
    self._dataUnitsByLinkColumnName = dict()
    # Names of DataUnits that carry a region.
    self._spatialDataUnits = frozenset()

@classmethod
def fromConfig(cls, config, builder=None):
Expand Down Expand Up @@ -364,9 +384,7 @@ def getRegionHolder(self, *dataUnitNames):
-------
`DataUnitRegion` or `DataUnit` instance.
"""
if len(dataUnitNames) == 1:
return self[dataUnitNames[0]]
return self._dataUnitRegions[frozenset(dataUnitNames)]
return self._dataUnitRegions[frozenset(dataUnitNames) & self._spatialDataUnits]

def getJoin(self, lhs, rhs):
"""Return the DataUnitJoin that relates the given DataUnit names.
Expand Down Expand Up @@ -462,8 +480,22 @@ def _initDataUnits(self, config, builder):
optionalDependencies=optionalDependencies,
table=table,
link=link,
regionColumn=regionColumn)
spatial=spatial)
self[dataUnitName] = dataUnit
for linkColumnName in link:
self._dataUnitsByLinkColumnName[linkColumnName] = dataUnit
if spatial is not None:
self._spatialDataUnits |= frozenset((dataUnitName, ))
# The DataUnit (or DataUnitRegion) instance that can be used
# to retrieve the region is keyed based on the union
# of the DataUnit and its required dependencies that are also spatial.
# E.g. 'Patch' is keyed on ('Tract', 'Patch').
# This requires that DataUnit's are visited in topologically sorted order
# (which they are).
key = frozenset((dataUnitName, ) +
tuple(d.name for d in dataUnit.requiredDependencies
if d.name in self._spatialDataUnits))
self._dataUnitRegions[key] = dataUnit

def _initDataUnitRegions(self, config, builder):
"""Initialize tables that associate regions with multiple DataUnits.
Expand All @@ -481,12 +513,13 @@ def _initDataUnitRegions(self, config, builder):
if builder is not None:
table = builder.addTable(tableName, tableDescription)
duRegion = DataUnitRegion(name=tableName,
relates=tuple(description["relates"]),
relates=frozenset(self[name] for name in dataUnitNames),
table=table,
spatial=description.get("spatial", False))
else:
duRegion = None
self._dataUnitRegions[dataUnitNames] = duRegion
self._spatialDataUnits |= frozenset(dataUnitNames)

def _initDataUnitJoins(self, config, builder):
"""Initialize `DataUnit` join entries.
Expand Down Expand Up @@ -531,3 +564,24 @@ def getPrimaryKeyNames(self, dataUnitNames):
All primary-key column names for the given ``dataUnitNames``.
"""
return set(chain.from_iterable(self[name].primaryKey for name in dataUnitNames))

def getByLinkName(self, name):
    """Get a `DataUnit` for which ``name`` is part of the link.

    Parameters
    ----------
    name : `str`
        Link name.

    Returns
    -------
    dataUnit : `DataUnit`
        The corresponding `DataUnit` instance.

    Raises
    ------
    KeyError
        When the provided ``name`` does not correspond to a link
        for any of the `DataUnit` entries in the registry.
    """
    # Plain dict lookup; a missing name propagates the KeyError.
    dataUnit = self._dataUnitsByLinkColumnName[name]
    return dataUnit
6 changes: 3 additions & 3 deletions python/lsst/daf/butler/core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ class DatasetRef:
datasetType : `DatasetType`
The `DatasetType` for this Dataset.
dataId : `dict`
Dictionary where the keys are `DataUnit` names and the values are
`DataUnit` values.
A `dict` of `DataUnit` link name, value pairs that label the
`DatasetRef` within a Collection.
id : `int`, optional
A unique identifier.
Normally set to `None` and assigned by `Registry`
Expand Down Expand Up @@ -200,7 +200,7 @@ def datasetType(self):

@property
def dataId(self):
    """A `dict` of `DataUnit` link name, value pairs that label the
    `DatasetRef` within a Collection.
    """
    return self._dataId
Expand Down
74 changes: 54 additions & 20 deletions python/lsst/daf/butler/registries/sqlRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from sqlalchemy.sql import select, and_, exists
from sqlalchemy.exc import IntegrityError

from lsst.sphgeom import ConvexPolygon

from ..core.utils import transactional

from ..core.datasets import DatasetType, DatasetRef
Expand Down Expand Up @@ -174,7 +176,7 @@ def registerDatasetType(self, datasetType):
if datasetType.dataUnits:
self._connection.execute(datasetTypeUnitsTable.insert(),
[{'dataset_type_name': datasetType.name, 'unit_name': dataUnitName}
for dataUnitName in datasetType.dataUnits])
for dataUnitName in datasetType.dataUnits])
self._datasetTypes[datasetType.name] = datasetType
# Also register component DatasetTypes (if any)
for compName, compStorageClass in datasetType.storageClass.components.items():
Expand Down Expand Up @@ -817,10 +819,11 @@ def addDataUnitEntry(self, dataUnitName, values):
except IntegrityError as err:
raise ValueError(str(err)) # TODO this should do an explicit validity check instead
if region is not None:
self.setDataUnitRegion((dataUnitName,), v, region, new=True)
self.setDataUnitRegion(
(dataUnitName,) + tuple(d.name for d in dataUnit.requiredDependencies), v, region)

@transactional
def setDataUnitRegion(self, dataUnitNames, value, region, new=True):
def setDataUnitRegion(self, dataUnitNames, value, region, update=True):
"""Set the region field for a DataUnit instance or a combination
thereof and update associated spatial join tables.
Expand All @@ -833,33 +836,30 @@ def setDataUnitRegion(self, dataUnitNames, value, region, new=True):
A dictionary of values that uniquely identify the DataUnits.
region : `sphgeom.ConvexPolygon`
Region on the sky.
new : `bool`
If True, the DataUnits associated identified are being inserted for
the first time, so no spatial regions should already exist.
If False, existing region information for these DataUnits is being
replaced.
update : `bool`
If True, existing region information for these DataUnits is being
replaced. This is usually required because DataUnit entries are
assumed to be pre-inserted prior to calling this function.
"""
keyColumns = {}
primaryKey = set()
for dataUnitName in dataUnitNames:
dataUnit = self._schema.dataUnits[dataUnitName]
dataUnit.validateId(value)
keyColumns.update(dataUnit.primaryKeyColumns)
primaryKey.update(dataUnit.primaryKey)
table = self._schema.dataUnits.getRegionHolder(*dataUnitNames).table
if table is None:
raise TypeError("No region table found for '{}'.".format(dataUnitNames))
# If a region record for these DataUnits already exists, use an update
# query. That could happen either because those DataUnits have been
# inserted previously and this is an improved region for them, or
# because the region is associated with a single DataUnit instance and
# is hence part of that DataUnit's main table.
if not new or (len(dataUnitNames) == 1 and table == dataUnit.table):
self._connection.execute(
# Update the region for an existing entry
if update:
result = self._connection.execute(
table.update().where(
and_((keyColumns[name] == value[name] for name in keyColumns))
and_((table.columns[name] == value[name] for name in primaryKey))
).values(
region=region.encode()
)
)
if result.rowcount == 0:
raise ValueError("No records were updated when setting region, did you forget update=False?")
else: # Insert rather than update.
self._connection.execute(
table.insert().values(
Expand All @@ -871,11 +871,11 @@ def setDataUnitRegion(self, dataUnitNames, value, region, new=True):
join = self._schema.dataUnits.getJoin(dataUnitNames, "SkyPix")
if join is None or join.isView:
return
if not new:
if update:
# Delete any old SkyPix join entries for this DataUnit
self._connection.execute(
join.table.delete().where(
and_((keyColumns[name] == value[name] for name in keyColumns))
and_((join.table.columns[name] == value[name] for name in primaryKey))
)
)
parameters = []
Expand Down Expand Up @@ -993,6 +993,40 @@ def find(self, collection, datasetType, dataId):
else:
return None

def getRegion(self, dataId):
    """Get region associated with a dataId.

    Parameters
    ----------
    dataId : `dict`
        A `dict` of `DataUnit` link name, value pairs that label the
        `DatasetRef` within a Collection.

    Returns
    -------
    region : `lsst.sphgeom.ConvexPolygon`
        The region associated with a ``dataId`` or ``None`` if not present.

    Raises
    ------
    KeyError
        If the set of dataunits for the ``dataId`` does not correspond to
        a unique spatial lookup.
    """
    dataUnits = self._schema.dataUnits
    # Translate link names (dataId keys) into DataUnit names and find the
    # object holding the region for that combination.
    unitNames = tuple(dataUnits.getByLinkName(linkName).name for linkName in dataId)
    regionHolder = dataUnits.getRegionHolder(*unitNames)
    # SkyPix regions are not stored in a table; generate one on the fly.
    if regionHolder == dataUnits["SkyPix"]:
        return self.pixelization.pixel(dataId["skypix"])
    # Otherwise look the region up in the holder's table, constrained on
    # the holder's full primary key.
    primaryKeyColumns = regionHolder.primaryKeyColumns
    whereClause = and_((primaryKeyColumns[name] == dataId[name]
                        for name in primaryKeyColumns))
    row = self._connection.execute(
        select([regionHolder.regionColumn]).where(whereClause)).fetchone()
    if row is None:
        return None
    return ConvexPolygon.decode(row[0])

@transactional
def subset(self, collection, expr, datasetTypes):
r"""Create a new `Collection` by subsetting an existing one.
Expand Down

0 comments on commit 0966e7d

Please sign in to comment.