Skip to content

Commit

Permalink
JointcalStatistics bugfix re: visit tracking, add lsstSim pa1 test
Browse files Browse the repository at this point in the history
Fixed a serious bug in how I was tracking visits in JointcalStatistics, as well
as a few less serious but non-trivial related bugs. In the process, cleaned up
the args that various methods take (removing ones that should just be computed
and handled internally) and renamed args in the plotting functions.

Added pa1 test to test_jointcal_lsstSim.py, using a value just above the current
computed value.

Several other docstring and comment cleanups.

Tweaked the default plot_jointcal output directory, and create it (mkdir) if it's missing.
  • Loading branch information
parejkoj committed Jan 13, 2017
1 parent a5119e6 commit 1bdf37d
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 70 deletions.
10 changes: 6 additions & 4 deletions bin.src/plot_jointcal_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,19 @@ def main():
help="Radius (degrees) of sources to load from reference catalog.")
parser.add_argument("-i", "--interactive", action="store_true",
help="Use interactive matplotlib backend and set ion(), in addition to saving files.")
parser.add_argument("-o", "--outdir", default=".plots",
parser.add_argument("-o", "--outdir", default="plots",
help="output directory for plots (default: $(default)s)")
parser.add_argument("-v", "--verbose", action="store_true",
help="Print extra things during calculations.")
args = parser.parse_args()

if not os.path.isdir(args.outdir):
os.mkdir(args.outdir)

butler = lsst.daf.persistence.Butler(inputs=args.repo)
dataIds = get_valid_dataIds(butler)

data_refs = [butler.dataRef('wcs', dataId=dataId) for dataId in dataIds]
visits = [data_ref.dataId['visit'] for data_ref in data_refs]
old_wcs_list = get_old_wcs_list(data_refs)

os.environ['ASTROMETRY_NET_DATA_DIR'] = args.refcat
Expand All @@ -89,10 +91,10 @@ def main():
reference = prep_reference_loader(center, args.radius*degrees)

jointcalStatistics = utils.JointcalStatistics(verbose=args.verbose)
jointcalStatistics.compute_rms(data_refs, visits, reference)
jointcalStatistics.compute_rms(data_refs, reference)

name = os.path.basename(os.path.normpath(args.repo))
jointcalStatistics.make_plots(data_refs, visits, old_wcs_list, name=name,
jointcalStatistics.make_plots(data_refs, old_wcs_list, name=name,
interactive=args.interactive, outdir=args.outdir)


Expand Down
95 changes: 50 additions & 45 deletions python/lsst/jointcal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,39 +51,47 @@ def __init__(self, match_radius=0.1*arcseconds, flux_limit=100.0, verbose=False)
self.verbose = verbose
self.log = lsst.log.Log.getLogger('JointcalStatistics')

def compute_rms(self, data_refs, visit_list, reference):
def compute_rms(self, data_refs, reference):
"""
Match all data_refs to compute the RMS, for all detections above self.flux_limit.
Parameters
----------
data_refs : list of lsst.daf.persistence.butlerSubset.ButlerDataRef
A list of data refs to do the calculations between.
visit_list : list of visit id (usually int)
list of visit identifiers to do the catalog merge on.
reference : lsst reference catalog
reference catalog to do absolute matching against.
Return
------
astropy.Quantity
New relative RMS of the matched sources.
astropy.Quantity
New absolute RMS of matched sources.
namedtuple:
astropy.Quantity
New relative RMS of the matched sources.
astropy.Quantity
New absolute RMS of matched sources.
float
post-jointcal photometric repeatability (PA1 from the SRD).
"""

# DECAM doesn't have "filter" in its registry, so we have to get filter names from VisitInfo.
self.filters = [ref.get('calexp').getInfo().getFilter().getName() for ref in data_refs]
visits_per_dataRef = [ref.dataId['visit'] for ref in data_refs]
self.visits_per_dataRef = [ref.dataId['visit'] for ref in data_refs]

def compute(catalogs, calibs):
"""Compute the relative and absolute matches in distance and flux."""
visit_catalogs = self._make_visit_catalogs(catalogs, visits_per_dataRef, visit_list)
catalogs = list(visit_catalogs.values())
visit_catalogs = self._make_visit_catalogs(catalogs, self.visits_per_dataRef)
catalogs = [visit_catalogs[x] for x in self.visits_per_dataRef]
# use the first catalog as the relative reference catalog
# NOTE: The "first" catalog depends on the original ordering of the data_refs.
# NOTE: Thus, because I'm doing a many-1 match in _make_match_dict,
# the number of matches (and thus the details of the match statistics)
# will change if the data_refs are ordered differently.
# All the more reason to use a proper n-way matcher here...
refcat = catalogs[0]
refcalib = calibs[0]
dist_rel, flux_rel, ref_flux_rel, source_rel = self._make_match_dict(refcat, catalogs[1:], calibs,
dist_rel, flux_rel, ref_flux_rel, source_rel = self._make_match_dict(refcat,
catalogs[1:],
calibs[1:],
refcalib=refcalib)
dist_abs, flux_abs, ref_flux_abs, source_abs = self._make_match_dict(reference, catalogs, calibs)
dist = MatchDict(dist_rel, dist_abs)
Expand All @@ -97,14 +105,12 @@ def compute(catalogs, calibs):
self.old_dist, self.old_flux, self.old_ref_flux, self.old_source = compute(old_cats, old_calibs)

# Update coordinates with the new wcs, and get the new Calibs.
new_cats = []
new_calibs = []
for ref in data_refs:
new_cats.append(ref.get('src'))
wcs = ref.get('wcs')
new_calibs.append(wcs.getCalib())
new_cats = [ref.get('src') for ref in data_refs]
new_wcss = [ref.get('wcs') for ref in data_refs]
new_calibs = [wcs.getCalib() for wcs in new_wcss]
for wcs, cat in zip(new_wcss, new_cats):
# update in-place the object coordinates based on the new wcs
lsst.afw.table.utils.updateSourceCoords(wcs.getWcs(), new_cats[-1])
lsst.afw.table.utils.updateSourceCoords(wcs.getWcs(), cat)

self.new_dist, self.new_flux, self.new_ref_flux, self.new_source = compute(new_cats, new_calibs)

Expand All @@ -124,9 +130,11 @@ def rms_total(data):
self.old_dist_total = MatchDict(*(tuple(map(rms_total, self.old_dist))*u.radian).to(u.arcsecond))
self.new_dist_total = MatchDict(*(tuple(map(rms_total, self.new_dist))*u.radian).to(u.arcsecond))

return self.new_dist_total.relative, self.new_dist_total.absolute
Rms_result = collections.namedtuple("rms_result", ["dist_relative", "dist_absolute", "pa1"])
result = Rms_result(self.new_dist_total.relative, self.new_dist_total.absolute, self.new_PA1)
return result

def make_plots(self, data_refs, visit_catalogs, old_wcs_list,
def make_plots(self, data_refs, old_wcs_list,
name='', interactive=False, per_ccd_plot=False, outdir='.plots'):
"""
Make plots of various quantities to help with debugging.
Expand All @@ -136,8 +144,6 @@ def make_plots(self, data_refs, visit_catalogs, old_wcs_list,
----------
data_refs : list of lsst.daf.persistence.butlerSubset.ButlerDataRef
A list of data refs to do the calculations between.
visit_catalogs : list of lsst.afw.table.SourceCatalog
visit source catalogs (values() produced by _make_visit_catalogs) to cross-match.
old_wcs_list : list of lsst.afw.image.wcs.Wcs
A list of the old (pre-jointcal) WCSs, one-to-one corresponding to data_refs.
name : str
Expand All @@ -164,7 +170,7 @@ def make_plots(self, data_refs, visit_catalogs, old_wcs_list,
plt.ion()

plot_flux_distributions(plt, self.old_mag, self.new_mag, self.old_jitter, self.new_jitter,
self.faint, self.bright, self.old_PA1, self.new_PA1,)
self.faint, self.bright, self.old_PA1, self.new_PA1, name=name, outdir=outdir)

def rms_per_source(data):
"""Each element of data must already be the "delta" of whatever measurement."""
Expand All @@ -183,7 +189,7 @@ def rms_per_source(data):
self.old_dist_total.relative, self.old_dist_total.absolute,
self.new_dist_total.relative, self.new_dist_total.absolute, name, outdir=outdir)

plot_all_wcs_deltas(plt, data_refs, visit_catalogs, old_wcs_list, name,
plot_all_wcs_deltas(plt, data_refs, self.visits_per_dataRef, old_wcs_list, name,
outdir=outdir, per_ccd_plot=per_ccd_plot)

if interactive:
Expand Down Expand Up @@ -253,11 +259,11 @@ def _make_match_dict(self, reference, visit_catalogs, calibs, refcalib=None):
Returns
-------
distances: dict
dict of sourceID: list(separation distances for that source)
dict of sourceID: array(separation distances for that source)
fluxes: dict
dict of sourceID: list(fluxes (Jy) for that source)
dict of sourceID: array(fluxes (Jy) for that source)
ref_fluxes: dict
dict of sourceID: list(fluxes (Jy) for the reference object)
dict of sourceID: flux (Jy) of the reference object
sources: dict
dict of sourceID: list(each SourceRecord that was position-matched to this sourceID)
"""
Expand Down Expand Up @@ -297,7 +303,7 @@ def _make_match_dict(self, reference, visit_catalogs, calibs, refcalib=None):

return distances, fluxes, ref_fluxes, sources

def _make_visit_catalogs(self, catalogs, visits, visit_list):
def _make_visit_catalogs(self, catalogs, visits):
"""
Merge all catalogs from each visit.
NOTE: creating this structure is somewhat slow, and will be unnecessary
Expand All @@ -308,16 +314,14 @@ def _make_visit_catalogs(self, catalogs, visits, visit_list):
catalogs : list of lsst.afw.table.SourceCatalog
Catalogs to combine into per-visit catalogs.
visits : list of visit id (usually int)
list of visit identifiers, one-to-one correspondant with catalogs.
visit_list : list of visit id (usually int)
list of visit identifiers to do the catalog merge on (a proper subset of visits).
list of visit identifiers, one-to-one correspondent with catalogs.
Returns
-------
dict
dict of visit: catalog of all sources from all CCDs of that visit.
"""
visit_dict = {v: lsst.afw.table.SourceCatalog(catalogs[0].schema) for v in visit_list}
visit_dict = {v: lsst.afw.table.SourceCatalog(catalogs[0].schema) for v in visits}
for v, cat in zip(visits, catalogs):
visit_dict[v].extend(cat)
# We want catalog contiguity to do object selection later.
Expand All @@ -338,6 +342,7 @@ def plot_flux_distributions(plt, old_mag, new_mag, old_jitter, new_jitter,

old_color = 'blue'
new_color = 'red'
plt.figure()
plt.plot(old_mag, old_jitter, '.', color=old_color, label='old')
plt.plot(new_mag, new_jitter, '.', color=new_color, label='new')
plt.axvline(faint, ls=':', color=old_color)
Expand All @@ -361,12 +366,12 @@ def plot_flux_distributions(plt, old_mag, new_mag, old_jitter, new_jitter,
plt.savefig(filename.format(name))


def plot_all_wcs_deltas(plt, data_refs, visit_catalogs, old_wcs_list, name,
def plot_all_wcs_deltas(plt, data_refs, visits, old_wcs_list, name,
per_ccd_plot=False, outdir='.plots'):
"""Various plots of the difference between old and new Wcs."""

plot_wcs_magnitude(plt, data_refs, visit_catalogs, old_wcs_list, name, outdir=outdir)
plot_all_wcs_quivers(plt, data_refs, visit_catalogs, old_wcs_list, name, outdir=outdir)
plot_wcs_magnitude(plt, data_refs, visits, old_wcs_list, name, outdir=outdir)
plot_all_wcs_quivers(plt, data_refs, visits, old_wcs_list, name, outdir=outdir)

if per_ccd_plot:
for i, ref in enumerate(data_refs):
Expand Down Expand Up @@ -398,15 +403,15 @@ def wcs_convert(xv, yv, wcs):
return xout, yout


def plot_all_wcs_quivers(plt, data_refs, visit_catalogs, old_wcs_list, name, outdir='.plots'):
def plot_all_wcs_quivers(plt, data_refs, visits, old_wcs_list, name, outdir='.plots'):
"""Make quiver plots of the WCS deltas for each CCD in each visit."""

for cat in visit_catalogs:
for visit in visits:
fig = plt.figure()
# fig.set_tight_layout(True)
ax = fig.add_subplot(111)
for old_wcs, ref in zip(old_wcs_list, data_refs):
if ref.dataId['visit'] != cat:
if ref.dataId['visit'] != visit:
continue
md = ref.get('calexp_md')
Q = plot_wcs_quivers(ax, old_wcs, ref.get('wcs').getWcs(),
Expand All @@ -417,9 +422,9 @@ def plot_all_wcs_quivers(plt, data_refs, visit_catalogs, old_wcs_list, name, out
ax.quiverkey(Q, 0.9, 0.95, length, '0.1 arcsec', coordinates='figure', labelpos='W')
plt.xlabel('RA')
plt.ylabel('Dec')
plt.title('visit: {}'.format(cat))
plt.title('visit: {}'.format(visit))
filename = os.path.join(outdir, '{}-{}-quivers.pdf')
plt.savefig(filename.format(name, cat))
plt.savefig(filename.format(name, visit))


def plot_wcs_quivers(ax, wcs1, wcs2, dim):
Expand All @@ -431,9 +436,9 @@ def plot_wcs_quivers(ax, wcs1, wcs2, dim):
return ax.quiver(x1, y1, uu, vv, units='x', pivot='tail', scale=1e-3, width=1e-5)


def plot_wcs_magnitude(plt, data_refs, visit_catalogs, old_wcs_list, name, outdir='.plots'):
def plot_wcs_magnitude(plt, data_refs, visits, old_wcs_list, name, outdir='.plots'):
"""Plot the magnitude of the WCS change between old and new visits as a heat map."""
for cat in visit_catalogs:
for visit in visits:
fig = plt.figure()
fig.set_tight_layout(True)
ax = fig.add_subplot(111)
Expand All @@ -443,7 +448,7 @@ def plot_wcs_magnitude(plt, data_refs, visit_catalogs, old_wcs_list, name, outdi
xmax = -np.inf
ymax = -np.inf
for old_wcs, ref in zip(old_wcs_list, data_refs):
if ref.dataId['visit'] != cat:
if ref.dataId['visit'] != visit:
continue
md = ref.get('calexp_md')
x1, y1, x2, y2 = make_xy_wcs_grid((md.get('NAXIS1'), md.get('NAXIS2')),
Expand All @@ -469,9 +474,9 @@ def plot_wcs_magnitude(plt, data_refs, visit_catalogs, old_wcs_list, name, outdi
plt.ylim(ymin, ymax)
plt.xlabel('RA')
plt.ylabel('Dec')
plt.title('visit: {}'.format(cat))
plt.title('visit: {}'.format(visit))
filename = os.path.join(outdir, '{}-{}-heatmap.pdf')
plt.savefig(filename.format(name, cat))
plt.savefig(filename.format(name, visit))


def plot_wcs(plt, wcs1, wcs2, dim, center=(0, 0), name="", outdir='.plots'):
Expand Down
26 changes: 15 additions & 11 deletions tests/jointcalTestBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from builtins import object

import os
import inspect

import lsst.afw.geom
from lsst.meas.astrom import LoadAstrometryNetObjectsTask, LoadAstrometryNetObjectsConfig
Expand Down Expand Up @@ -74,12 +75,15 @@ def _prep_reference_loader(self, center, radius):
# Make a copy of the reference catalog for in-memory contiguity.
self.reference = refLoader.loadSkyCircle(center, radius, filterName='r').refCat.copy()

def _testJointCalTask(self, nCatalogs, relative_error, absolute_error):
"""Test parseAndRun for jointcal on nCatalogs, requiring less than some error (arcsec)."""
def _testJointCalTask(self, nCatalogs, dist_rms_relative, dist_rms_absolute, pa1):
"""
Test parseAndRun for jointcal on nCatalogs.
Checks relative and absolute astrometric error (arcsec) and photometric
repeatability (PA1 from the SRD).
"""

visit_list = self.all_visits[:nCatalogs]
visits = '^'.join(str(v) for v in visit_list)
import inspect
visits = '^'.join(str(v) for v in self.all_visits[:nCatalogs])
# the calling method is one step back on the stack: use it to specify the output repo.
caller = inspect.stack()[1][3] # NOTE: could be inspect.stack()[1].function in py3.5
output_dir = os.path.join('.test', self.__class__.__name__, caller)
Expand All @@ -90,13 +94,13 @@ def _testJointCalTask(self, nCatalogs, relative_error, absolute_error):
args.extend(self.other_args)
result = jointcal.JointcalTask.parseAndRun(args=args, doReturnResults=True)
self.assertNotEqual(result.resultList, [], 'resultList should not be empty')
self.dataRefs = result.resultList[0].result.dataRefs
data_refs = result.resultList[0].result.dataRefs
oldWcsList = result.resultList[0].result.oldWcsList

rms_rel, rms_abs = self.jointcalStatistics.compute_rms(self.dataRefs, visit_list, self.reference)
self.assertLess(rms_rel, relative_error)
self.assertLess(rms_abs, absolute_error)
rms_result = self.jointcalStatistics.compute_rms(data_refs, self.reference)
self.assertLess(rms_result.dist_relative, dist_rms_relative)
self.assertLess(rms_result.dist_absolute, dist_rms_absolute)
self.assertLess(rms_result.pa1, pa1)

if self.do_plot:
name = self.id.strip('__main__.')
self.jointcalStatistics.make_plots(self.dataRefs, self.visitCatalogs, oldWcsList, name=name)
self.jointcalStatistics.make_plots(data_refs, oldWcsList, name=caller)
24 changes: 14 additions & 10 deletions tests/test_jointcal_lsstSim.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# than the single-epoch astrometry (about 0.040").
# This value was empirically determined from the first run of jointcal on
# this data, and will likely vary from survey to survey.
absolute_error = 42e-3*u.arcsecond
dist_rms_absolute = 42e-3*u.arcsecond


# for MemoryTestCase
Expand Down Expand Up @@ -54,32 +54,36 @@ def setUp(self):
@unittest.skipIf(data_dir is None, "testdata_jointcal not setup")
@unittest.skip('jointcal currently fails if only given one catalog!')
def testJointCalTask_1_visits(self):
self._testJointCalTask(2, 0, absolute_error)
self._testJointCalTask(2, 0, dist_rms_absolute)

@unittest.skipIf(data_dir is None, "testdata_jointcal not setup")
def testJointCalTask_2_visits(self):
# NOTE: The relative RMS limits were empirically determined from the
# first run of jointcal on this data. We should always do better than
# this in the future!
relative_error = 9.7e-3*u.arcsecond
self._testJointCalTask(2, relative_error, absolute_error)
dist_rms_relative = 9.7e-3*u.arcsecond
pa1 = 2.64e-3
self._testJointCalTask(2, dist_rms_relative, dist_rms_absolute, pa1)

@unittest.skipIf(data_dir is None, "testdata_jointcal not setup")
@unittest.skip('Keeping this around for diagnostics on the behavior with n catalogs.')
def testJointCalTask_4_visits(self):
relative_error = 8.2e-3*u.arcsecond
self._testJointCalTask(4, relative_error, absolute_error)
dist_rms_relative = 8.2e-3*u.arcsecond
pa1 = 2.64e-3
self._testJointCalTask(4, dist_rms_relative, dist_rms_absolute, pa1)

@unittest.skipIf(data_dir is None, "testdata_jointcal not setup")
@unittest.skip('Keeping this around for diagnostics on the behavior with n catalogs.')
def testJointCalTask_7_visits(self):
relative_error = 8.1e-3*u.arcsecond
self._testJointCalTask(7, relative_error, absolute_error)
dist_rms_relative = 8.1e-3*u.arcsecond
pa1 = 2.64e-3
self._testJointCalTask(7, dist_rms_relative, dist_rms_absolute, pa1)

@unittest.skipIf(data_dir is None, "testdata_jointcal not setup")
def testJointCalTask_10_visits(self):
relative_error = 7.9e-3*u.arcsecond
self._testJointCalTask(10, relative_error, absolute_error)
dist_rms_relative = 7.9e-3*u.arcsecond
pa1 = 2.64e-3
self._testJointCalTask(10, dist_rms_relative, dist_rms_absolute, pa1)


# TODO: the memory test cases currently fail in jointcal. Filed as DM-6626.
Expand Down

0 comments on commit 1bdf37d

Please sign in to comment.