Skip to content

Commit

Permalink
Add methods to extract Astropy table views into LSST Catalogs
Browse files Browse the repository at this point in the history
The __astropy_table__ method is called by the Astropy Table
constructor (in Astropy >= v1.2) to support the syntax:

t = astropy.table.Table(catalog)

We also provide our own direct interface that returns an Astropy Table
(or QTable) as an "asAstropy" method.
  • Loading branch information
TallJimbo committed May 17, 2016
1 parent 908c0dc commit 44cac51
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/lsst/afw/table/Base.i
Original file line number Diff line number Diff line change
Expand Up @@ -686,4 +686,16 @@ namespace lsst { namespace afw { namespace table {

%declareCatalog(CatalogT, Base)

// This needs to be here, not Catalog.i, to prevent it from being picked up in afw.detection, where _syntax.py is not available.
%extend CatalogT<BaseRecord> {
    %pythoncode %{
        # Expose the pure-Python implementation directly as a bound method.
        asAstropy = _syntax.BaseCatalog_asAstropy

        def __astropy_table__(self, cls, copy, **kwds):
            """Implement interface called by Astropy table constructors to construct a view or copy.

            @param[in]   cls     Table class (astropy.table.Table or a subclass) to construct.
            @param[in]   copy    Whether to copy the catalog data instead of viewing it.
            @param[in]   kwds    Additional keyword arguments; accepted but currently unused.

            @return the object returned by _syntax.BaseCatalog_asAstropy for this catalog.
            """
            # BaseCatalog_asAstropy is a plain function whose first parameter is the
            # catalog, so 'self' must be passed explicitly when calling it via the module.
            return _syntax.BaseCatalog_asAstropy(self, cls=cls, copy=copy)
    %}
}

}}} // namespace lsst::afw::table
67 changes: 67 additions & 0 deletions python/lsst/afw/table/_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,70 @@ def processArray(a):
else:
d[name] = processArray(self.get(schemaItem.key))
return d

def BaseCatalog_asAstropy(self, cls=None, copy=False, unviewable="copy"):
    """!
    Return an astropy.table.Table (or subclass thereof) view into this catalog.

    @param[in]   cls        Table subclass to use; None implies astropy.table.Table itself.
                            Use astropy.table.QTable to get Quantity columns.

    @param[in]   copy       Whether to copy data from the LSST catalog to the astropy table.
                            Not copying is usually faster, but can keep memory from being
                            freed if columns are later removed from the Astropy view.

    @param[in]   unviewable One of the following options, indicating how to handle field types
                            (string and Flag) for which views cannot be constructed:
                              - 'copy' (default): copy only the unviewable fields.
                              - 'raise': raise ValueError if unviewable fields are present.
                              - 'skip': do not include unviewable fields in the Astropy Table.
                            This option is ignored if copy=True.

    @return an instance of cls with one astropy.table.Column per included schema field.

    @throw ValueError if 'unviewable' is not one of the recognized options, or if
                      unviewable='raise', copy=False and a String or Flag field is present.
    """
    import astropy.table
    if cls is None:
        cls = astropy.table.Table
    if unviewable not in ("copy", "raise", "skip"):
        raise ValueError("'unviewable' must be one of 'copy', 'raise', or 'skip'")
    ps = self.getMetadata()
    meta = ps.toOrderedDict() if ps is not None else None
    columns = []
    # Only the schema items are needed, not their names; values() (unlike
    # iteritems()) is available on both Python 2 and Python 3 dictionaries.
    for item in self.schema.extract("*", ordered=True).values():
        key = item.key
        unit = item.field.getUnits() or None  # use None instead of "" when empty
        typeString = key.getTypeString()
        if typeString == "String":
            if not copy:
                if unviewable == "raise":
                    raise ValueError("Cannot extract string unless copy=True or unviewable='copy' or 'skip'.")
                elif unviewable == "skip":
                    continue
            # Strings are always copied, record by record, into a new fixed-width array.
            data = numpy.zeros(len(self), dtype=numpy.dtype((str, key.getSize())))
            for i, record in enumerate(self):
                data[i] = record.get(key)
        elif typeString == "Flag":
            if not copy:
                if unviewable == "raise":
                    raise ValueError(
                        "Cannot extract packed bit columns unless copy=True or unviewable='copy' or 'skip'."
                    )
                elif unviewable == "skip":
                    continue
            data = self.columns.get_bool_array(key)
        else:
            if typeString == "Angle":
                # Angle fields are always labelled as radians, overriding the field's units.
                unit = "radian"
            data = self.columns.get(key)
            if copy:
                data = data.copy()
        columns.append(
            astropy.table.Column(
                data,
                name=item.field.getName(),
                unit=unit,
                description=item.field.getDoc()
            )
        )
    # Any copying that was requested has already been done per column above.
    return cls(columns, meta=meta, copy=False)
197 changes: 197 additions & 0 deletions tests/testAstropyTableViews.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#!/usr/bin/env python2
from __future__ import absolute_import, division

#
# LSST Data Management System
# Copyright 2016 AURA/LSST
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <http://www.lsstcorp.org/LegalNotices/>.
#

"""
Tests for Astropy views into afw.table Catalogs
Run with:
./testAstropyTableViews.py
or
python
>>> import testAstropyTableViews; testAstropyTableViews.run()
"""

import unittest
import operator

import numpy
import astropy.table
import astropy.units

import lsst.utils.tests
import lsst.afw.table


class AstropyTableViewTestCase(lsst.utils.tests.TestCase):
    """Test that we can construct Astropy views to afw.table Catalog objects.

    This test case does not yet test the syntax

        table = astropy.table.Table(lsst_catalog)

    which is made available by BaseCatalog.__astropy_table__, as this will not
    be available until Astropy 1.2 is released.  However, this simply
    delegates to BaseCatalog.asAstropy, which can also be called directly.
    """

    def setUp(self):
        """Build a two-record catalog with one field of each kind exercised by the tests."""
        schema = lsst.afw.table.Schema()
        self.k1 = schema.addField("a1", type=float, units="meter", doc="a1 (meter)")
        self.k2 = schema.addField("a2", type=int, doc="a2 (unitless)")
        self.k3 = schema.addField("a3", type="ArrayF", size=3, units="count", doc="a3 (array, counts)")
        self.k4 = schema.addField("a4", type="Flag", doc="a4 (flag)")
        self.k5 = lsst.afw.table.CoordKey.addFields(schema, "a5", "a5 coordinate")
        self.k6 = schema.addField("a6", type=str, size=8, doc="a6 (str)")
        self.catalog = lsst.afw.table.BaseCatalog(schema)
        # NOTE(review): lsst.afw.geom is not imported explicitly in the visible part of
        # this file; presumably it is loaded as a side effect of importing
        # lsst.afw.table -- confirm.
        self.data = [
            {
                "a1": 5.0, "a2": 3, "a3": numpy.array([0.5, 0.0, -0.5], dtype=numpy.float32),
                "a4": True, "a5_ra": 45.0*lsst.afw.geom.degrees, "a5_dec": 30.0*lsst.afw.geom.degrees,
                "a6": "bubbles"
            },
            {
                "a1": 2.5, "a2": 7, "a3": numpy.array([1.0, 0.5, -1.5], dtype=numpy.float32),
                "a4": False, "a5_ra": 25.0*lsst.afw.geom.degrees, "a5_dec": -60.0*lsst.afw.geom.degrees,
                "a6": "pingpong"
            },
        ]
        for d in self.data:
            record = self.catalog.addNew()
            # items() works on both Python 2 and Python 3; iteritems() is Python 2 only.
            for k, v in d.items():
                record.set(k, v)

    def tearDown(self):
        """Drop all references created in setUp."""
        del self.k1
        del self.k2
        del self.k3
        del self.k4
        del self.k5
        del self.k6
        del self.catalog
        del self.data

    def testQuantityColumn(self):
        """Test that a column with units is handled as expected by Table and QTable.
        """
        v1 = self.catalog.asAstropy(cls=astropy.table.Table, unviewable="skip")
        self.assertEqual(v1["a1"].unit, astropy.units.Unit("m"))
        self.assertClose(v1["a1"], self.catalog["a1"])
        self.assertNotIsInstance(v1["a1"], astropy.units.Quantity)
        v2 = self.catalog.asAstropy(cls=astropy.table.QTable, unviewable="skip")
        self.assertEqual(v2["a1"].unit, astropy.units.Unit("m"))
        self.assertClose(v2["a1"]/astropy.units.Quantity(self.catalog["a1"]*100, "cm"), 1.0)
        self.assertIsInstance(v2["a1"], astropy.units.Quantity)

    def testUnitlessColumn(self):
        """Test that a column without units is handled as expected by Table and QTable.
        """
        v1 = self.catalog.asAstropy(cls=astropy.table.Table, unviewable="skip")
        self.assertEqual(v1["a2"].unit, None)
        self.assertClose(v1["a2"], self.catalog["a2"])  # use assertClose just because it handles arrays
        v2 = self.catalog.asAstropy(cls=astropy.table.QTable, unviewable="skip")
        self.assertEqual(v2["a2"].unit, None)
        self.assertClose(v2["a2"], self.catalog["a2"])

    def testArrayColumn(self):
        """Test that an array column appears as a 2-d array with the expected shape.
        """
        v = self.catalog.asAstropy(unviewable="skip")
        self.assertClose(v["a3"], self.catalog["a3"])

    def testFlagColumn(self):
        """Test that Flag columns can be viewed if copy=True or unviewable="copy".
        """
        v1 = self.catalog.asAstropy(unviewable="copy")
        self.assertClose(v1["a4"], self.catalog["a4"])
        v2 = self.catalog.asAstropy(copy=True)
        self.assertClose(v2["a4"], self.catalog["a4"])

    def testCoordColumn(self):
        """Test that Coord columns appear as a pair of columns with correct angle units.
        """
        v1 = self.catalog.asAstropy(cls=astropy.table.Table, unviewable="skip")
        self.assertClose(v1["a5_ra"], self.catalog["a5_ra"])
        self.assertEqual(v1["a5_ra"].unit, astropy.units.Unit("rad"))
        self.assertNotIsInstance(v1["a5_ra"], astropy.units.Quantity)
        self.assertClose(v1["a5_dec"], self.catalog["a5_dec"])
        self.assertEqual(v1["a5_dec"].unit, astropy.units.Unit("rad"))
        self.assertNotIsInstance(v1["a5_dec"], astropy.units.Quantity)
        v2 = self.catalog.asAstropy(cls=astropy.table.QTable, unviewable="skip")
        self.assertClose(v2["a5_ra"]/astropy.units.Quantity(self.catalog["a5_ra"], unit="rad"), 1.0)
        self.assertEqual(v2["a5_ra"].unit, astropy.units.Unit("rad"))
        self.assertIsInstance(v2["a5_ra"], astropy.units.Quantity)
        self.assertClose(v2["a5_dec"]/astropy.units.Quantity(self.catalog["a5_dec"], unit="rad"), 1.0)
        self.assertEqual(v2["a5_dec"].unit, astropy.units.Unit("rad"))
        self.assertIsInstance(v2["a5_dec"], astropy.units.Quantity)

    def testStringColumn(self):
        """Test that string columns can be viewed if copy=True or unviewable='copy'.
        """
        v1 = self.catalog.asAstropy(unviewable="copy")
        self.assertEqual(v1["a6"][0], self.data[0]["a6"])
        self.assertEqual(v1["a6"][1], self.data[1]["a6"])
        v2 = self.catalog.asAstropy(copy=True)
        self.assertEqual(v2["a6"][0], self.data[0]["a6"])
        self.assertEqual(v2["a6"][1], self.data[1]["a6"])

    def testRaiseOnUnviewable(self):
        """Test that we can't view this table without copying, since it has Flag and String columns.
        """
        self.assertRaises(ValueError, self.catalog.asAstropy, copy=False, unviewable="raise")

    def testNoUnnecessaryCopies(self):
        """Test that fields that aren't Flag or String are not copied when copy=False (the default).
        """
        # Writing through the Astropy column must be visible in the catalog if it is a view.
        v1 = self.catalog.asAstropy(unviewable="copy")
        v1["a2"][0] = 4
        self.assertEqual(self.catalog[0]["a2"], 4)
        v2 = self.catalog.asAstropy(unviewable="skip")
        v2["a2"][1] = 10
        self.assertEqual(self.catalog[1]["a2"], 10)

    def testUnviewableSkip(self):
        """Test that we can skip unviewable columns.
        """
        v1 = self.catalog.asAstropy(unviewable="skip")
        self.assertRaises(KeyError, operator.getitem, v1, "a4")
        self.assertRaises(KeyError, operator.getitem, v1, "a6")

def suite():
    """Build and return a TestSuite holding every test case defined in this module."""
    lsst.utils.tests.init()

    collected = []
    # Each makeSuite() result is flattened into the list, test by test, in order.
    for caseClass in (AstropyTableViewTestCase, lsst.utils.tests.MemoryTestCase):
        collected.extend(unittest.makeSuite(caseClass))
    return unittest.TestSuite(collected)

def run(shouldExit=False):
    """Run this module's test suite, forwarding shouldExit to lsst.utils.tests.run."""
    allTests = suite()
    lsst.utils.tests.run(allTests, shouldExit)

# Script entry point: run the tests and ask the runner to exit when done.
if __name__ == "__main__":
    run(shouldExit=True)

0 comments on commit 44cac51

Please sign in to comment.