Skip to content

Commit

Permalink
put tester function into base class
Browse files Browse the repository at this point in the history
  • Loading branch information
n8pease committed Nov 3, 2020
1 parent 3a80090 commit d4e8cdd
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 42 deletions.
56 changes: 29 additions & 27 deletions python/lsst/daf/butler/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,33 @@
import io


def assertAstropyTablesEqual(testCase, tables, expectedTables):
"""Verify that a list of astropy tables matches a list of expected
astropy tables.
class ButlerTestHelper:
"""Mixin with helpers for unit tests."""

Parameters
----------
testCase : `unittest.TestCase`
The test case performing the assert.
tables : `astropy.table.Table` or iterable [`astropy.table.Table`]
The table or tables that should match the expected tables.
expectedTables : `astropy.table.Table` or iterable [`astropy.table.Table`]
The tables with expected values to which the tables under test will be
compared.
"""
# If a single table is passed in for tables or expectedTables, put it in a
# list.
if isinstance(tables, AstropyTable):
tables = [tables]
if isinstance(expectedTables, AstropyTable):
expectedTables = [expectedTables]
diff = io.StringIO()
testCase.assertEqual(len(tables), len(expectedTables))
for table, expected in zip(tables, expectedTables):
# Assert that we are testing what we think we are testing:
testCase.assertIsInstance(table, AstropyTable)
testCase.assertIsInstance(expected, AstropyTable)
# Assert that they match:
testCase.assertTrue(report_diff_values(table, expected, fileobj=diff), msg="\n" + diff.getvalue())
def assertAstropyTablesEqual(self, tables, expectedTables):
"""Verify that a list of astropy tables matches a list of expected
astropy tables.
Parameters
----------
tables : `astropy.table.Table` or iterable [`astropy.table.Table`]
The table or tables that should match the expected tables.
expectedTables : `astropy.table.Table`
or iterable [`astropy.table.Table`]
The tables with expected values to which the tables under test will
be compared.
"""
# If a single table is passed in for tables or expectedTables, put it
# in a list.
if isinstance(tables, AstropyTable):
tables = [tables]
if isinstance(expectedTables, AstropyTable):
expectedTables = [expectedTables]
diff = io.StringIO()
self.assertEqual(len(tables), len(expectedTables))
for table, expected in zip(tables, expectedTables):
# Assert that we are testing what we think we are testing:
self.assertIsInstance(table, AstropyTable)
self.assertIsInstance(expected, AstropyTable)
# Assert that they match:
self.assertTrue(report_diff_values(table, expected, fileobj=diff), msg="\n" + diff.getvalue())
14 changes: 7 additions & 7 deletions tests/test_cliCmdQueryDataIds.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@
)
from lsst.daf.butler import script
from lsst.daf.butler.tests import addDatasetType, MetricsExample
from lsst.daf.butler.tests.utils import assertAstropyTablesEqual
from lsst.daf.butler.tests.utils import ButlerTestHelper


TESTDIR = os.path.abspath(os.path.dirname(__file__))


class QueryDataIdsTest(unittest.TestCase):
class QueryDataIdsTest(unittest.TestCase, ButlerTestHelper):

configFile = os.path.join(TESTDIR, "config/basic/butler.yaml")
storageClassFactory = StorageClassFactory()
Expand Down Expand Up @@ -131,7 +131,7 @@ def testDimensions(self):
)),
names=("band", "instrument", "physical_filter", "visit_system", "visit")
)
assertAstropyTablesEqual(self, res, expected)
self.assertAstropyTablesEqual(res, expected)

def testNull(self):
"Test asking for nothing."
Expand All @@ -148,7 +148,7 @@ def testDatasets(self):
)),
names=("band", "instrument", "physical_filter", "visit_system", "visit")
)
assertAstropyTablesEqual(self, res, expected)
self.assertAstropyTablesEqual(res, expected)

def testWhere(self):
"""Test getting datasets."""
Expand All @@ -159,7 +159,7 @@ def testWhere(self):
)),
names=("band", "instrument", "physical_filter", "visit_system", "visit")
)
assertAstropyTablesEqual(self, res, expected)
self.assertAstropyTablesEqual(res, expected)

def testCollections(self):
"""Test getting datasets using the collections option."""
Expand All @@ -184,7 +184,7 @@ def testCollections(self):
)),
names=("band", "instrument", "physical_filter", "visit_system", "visit")
)
assertAstropyTablesEqual(self, res, expected)
self.assertAstropyTablesEqual(res, expected)

# Verify the new dataset is found in the "foo" collection.
res = self._queryDataIds(repo=self.root, dimensions=("visit",), collections=("foo",),
Expand All @@ -195,7 +195,7 @@ def testCollections(self):
)),
names=("band", "instrument", "physical_filter", "visit_system", "visit")
)
assertAstropyTablesEqual(self, res, expected)
self.assertAstropyTablesEqual(res, expected)


if __name__ == "__main__":
Expand Down
16 changes: 8 additions & 8 deletions tests/test_cliCmdQueryDatasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@
)
from lsst.daf.butler import script
from lsst.daf.butler.tests import MetricsExample
from lsst.daf.butler.tests.utils import assertAstropyTablesEqual
from lsst.daf.butler.tests.utils import ButlerTestHelper


TESTDIR = os.path.abspath(os.path.dirname(__file__))


class QueryDatasetsTest(unittest.TestCase):
class QueryDatasetsTest(unittest.TestCase, ButlerTestHelper):

configFile = os.path.join(TESTDIR, "config/basic/butler.yaml")
storageClassFactory = StorageClassFactory()
Expand Down Expand Up @@ -165,7 +165,7 @@ def testShowURI(self):
"visit", "URI")),
)

assertAstropyTablesEqual(self, tables, expectedTables)
self.assertAstropyTablesEqual(tables, expectedTables)

def testNoShowURI(self):
"""Test for expected output without show_uri (default is False)."""
Expand All @@ -179,7 +179,7 @@ def testNoShowURI(self):
),
)

assertAstropyTablesEqual(self, tables, expectedTables)
self.assertAstropyTablesEqual(tables, expectedTables)

def testWhere(self):
"""Test using the where clause to reduce the number of rows returned.
Expand All @@ -193,7 +193,7 @@ def testWhere(self):
),
)

assertAstropyTablesEqual(self, tables, expectedTables)
self.assertAstropyTablesEqual(tables, expectedTables)

def testGlobDatasetType(self):
"""Test specifying dataset type."""
Expand Down Expand Up @@ -226,7 +226,7 @@ def testGlobDatasetType(self):
)
)

assertAstropyTablesEqual(self, tables, expectedTables)
self.assertAstropyTablesEqual(tables, expectedTables)

def testFindFirstAndCollections(self):
"""Test the find-first option, and the collections option, since it
Expand Down Expand Up @@ -299,7 +299,7 @@ def testFindFirstAndCollections(self):
"visit", "URI")),
)

assertAstropyTablesEqual(self, tables, expectedTables)
self.assertAstropyTablesEqual(tables, expectedTables)

# Verify that with find first the duplicate dataset is eliminated and
# the more recent dataset is returned.
Expand Down Expand Up @@ -349,7 +349,7 @@ def testFindFirstAndCollections(self):
"visit", "URI")),
)

assertAstropyTablesEqual(self, tables, expectedTables)
self.assertAstropyTablesEqual(tables, expectedTables)


if __name__ == "__main__":
Expand Down

0 comments on commit d4e8cdd

Please sign in to comment.