Skip to content

Commit

Permalink
Add tests and cleanup docs and API.
Browse files Browse the repository at this point in the history
  • Loading branch information
parejkoj committed Nov 16, 2018
1 parent 2d99eb7 commit 4653cd1
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 8 deletions.
33 changes: 25 additions & 8 deletions python/lsst/pipe/tasks/dcrAssembleCoadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,19 @@ class DcrAssembleCoaddConfig(CompareWarpAssembleCoaddConfig):
)
baseGain = pexConfig.Field(
dtype=float,
doc="Relative weight to give the new solution when updating the model."
doc="Relative weight to give the new solution vs. the last solution when updating the model."
"A value of 1.0 gives equal weight to both solutions."
"If ``baseGain`` is set to zero, a conservative gain "
"will be calculated from the number of subfilters",
default=0.,
"If ``baseGain`` is None, a conservative gain "
"will be calculated from the number of subfilters. "
"Small values imply slower convergence of the solution, but can "
"help prevent overshooting and failures in the fit.",
default=None,
)
useProgressiveGain = pexConfig.Field(
dtype=bool,
doc="Use a gain that slowly increases above ``baseGain`` to accelerate convergence?",
doc="Use a gain that slowly increases above ``baseGain`` to accelerate convergence? "
"When calculating the next gain, we use up to 5 previous gains and convergence values."
"Can be set to False to force the model to change at the rate of ``baseGain``. ",
default=True,
)
doAirmassWeight = pexConfig.Field(
Expand Down Expand Up @@ -797,21 +801,31 @@ def calculateGain(self, convergenceList, gainList):
convergenceList : `list` of `float`
The quality of fit metric from each previous iteration.
gainList : `list` of `float`
The gains used in each previous iteration.
The gains used in each previous iteration: appended with new gain
value. Gains are numbers between ``self.config.baseGain`` and 1.
Returns
-------
gain : `float`
Relative weight to give the new solution when updating the model.
A value of 1.0 gives equal weight to both solutions.
Raises
------
ValueError
If ``len(convergenceList) != len(gainList)+1``.
"""
if self.config.baseGain <= 0:
nIter = len(convergenceList)
if nIter != len(gainList) + 1:
raise ValueError("convergenceList must be one element longer than gainList.")

if self.config.baseGain is None:
# If ``baseGain`` is not set, calculate it from the number of DCR subfilters
# The more subfilters being modeled, the lower the gain should be.
baseGain = 1./(self.config.dcrNumSubfilters - 1)
else:
baseGain = self.config.baseGain
nIter = len(convergenceList)

if self.config.useProgressiveGain and nIter > 2:
# To calculate the best gain to use, compare the past gains that have been used
# with the resulting convergences to estimate the best gain to use.
Expand All @@ -822,6 +836,9 @@ def calculateGain(self, convergenceList, gainList):
# weighted by the gains used in each previous iteration.
estFinalConv = [((1 + gainList[i])*convergenceList[i + 1] - convergenceList[i])/gainList[i]
for i in range(nIter - 1)]
            # A negative estimated final convergence is unphysical (the
            # convergence metric is non-negative); it indicates a bad fit of
            # the extrapolation, so clamp those estimates to zero before
            # taking the median below.
estFinalConv = np.array(estFinalConv)
estFinalConv[estFinalConv < 0] = 0
# Because the estimate may slowly change over time, only use the most recent measurements.
estFinalConv = np.median(estFinalConv[max(nIter - 5, 0):])
lastGain = gainList[nIter - 2]
Expand Down
134 changes: 134 additions & 0 deletions tests/test_dcrAssembleCoadd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest
import unittest.mock

import lsst.utils.tests

import lsst.afw.image
import lsst.daf.persistence
from lsst.pipe.tasks.dcrAssembleCoadd import DcrAssembleCoaddTask, DcrAssembleCoaddConfig


class DcrAssembleCoaddCalculateGainTestCase(lsst.utils.tests.TestCase):
    """Tests of DcrAssembleCoaddTask.calculateGain()."""
    def setUp(self):
        self.baseGain = 0.5
        self.gainList = [self.baseGain, self.baseGain]
        self.convergenceList = [0.2]
        # Calculate the convergence we would expect if the model was converging perfectly,
        # so that the improvement is limited only by our conservative gain.
        for i in range(2):
            self.convergenceList.append(self.convergenceList[i]/(self.baseGain + 1))
        # Gain the progressive algorithm should choose in the perfectly-converging case.
        self.nextGain = (1 + self.baseGain) / 2

        # NOTE(review): several tests mutate self.config after the task is
        # constructed; this relies on the task sharing this config object —
        # confirm the config is not frozen at task construction.
        self.config = DcrAssembleCoaddConfig()
        self.task = DcrAssembleCoaddTask(self.config)

    def testUnbalancedLists(self):
        """calculateGain() raises ValueError unless convergenceList is
        exactly one element longer than gainList.
        """
        gainList = [1, 2, 3, 4]
        convergenceList = [1, 2]
        with self.assertRaises(ValueError):
            self.task.calculateGain(convergenceList, gainList)

    def testNoProgressiveGain(self):
        """With useProgressiveGain disabled, calculateGain() returns
        baseGain unchanged and appends it to gainList.
        """
        self.config.useProgressiveGain = False
        self.config.baseGain = self.baseGain
        expectGain = self.baseGain
        expectGainList = self.gainList + [expectGain]
        result = self.task.calculateGain(self.convergenceList, self.gainList)
        self.assertEqual(result, expectGain)
        self.assertEqual(self.gainList, expectGainList)

    def testBaseGainNone(self):
        """If baseGain is None, gain is calculated from the default values."""
        self.config.useProgressiveGain = False
        expectGain = 1 / (self.config.dcrNumSubfilters - 1)
        expectGainList = self.gainList + [expectGain]
        result = self.task.calculateGain(self.convergenceList, self.gainList)
        self.assertEqual(result, expectGain)
        self.assertEqual(self.gainList, expectGainList)

    def testProgressiveFirstStep(self):
        """The first and second steps always return baseGain."""
        convergenceList = self.convergenceList[:1]
        gainList = []
        self.config.baseGain = self.baseGain
        expectGain = self.baseGain
        expectGainList = [expectGain]
        result = self.task.calculateGain(convergenceList, gainList)
        self.assertEqual(result, expectGain)
        self.assertEqual(gainList, expectGainList)

    def testProgressiveSecondStep(self):
        """The first and second steps always return baseGain."""
        convergenceList = self.convergenceList[:2]
        gainList = self.gainList[:1]
        self.config.baseGain = self.baseGain
        expectGain = self.baseGain
        expectGainList = gainList + [expectGain]
        result = self.task.calculateGain(convergenceList, gainList)
        self.assertEqual(result, expectGain)
        self.assertEqual(gainList, expectGainList)

    def testProgressiveGain(self):
        """Test that gain follows the "perfect" situation defined in setUp."""
        self.config.baseGain = self.baseGain
        expectGain = self.nextGain
        expectGainList = self.gainList + [expectGain]
        result = self.task.calculateGain(self.convergenceList, self.gainList)
        self.assertFloatsAlmostEqual(result, expectGain)
        self.assertEqual(self.gainList, expectGainList)

    def testProgressiveGainBadFit(self):
        """Test that gain is reduced if the predicted convergence does not
        match the measured convergence (in this case, converging too quickly).
        """
        wrongGain = 1.0
        gainList = [self.baseGain, self.baseGain]
        convergenceList = [0.2]
        for i in range(2):
            convergenceList.append(convergenceList[i]/(wrongGain + 1))
        # The below math is a simplified version of the full algorithm,
        # assuming the predicted convergence is zero.
        # Note that in this case, nextGain is smaller than wrongGain.
        nextGain = (self.baseGain + (1 + self.baseGain) / (1 + wrongGain)) / 2

        self.config.baseGain = self.baseGain
        # Build the expectation from the *local* gainList that is passed to
        # (and mutated by) calculateGain; the original compared against
        # self.gainList, which passed only because the two started equal.
        expectGainList = gainList + [nextGain]
        result = self.task.calculateGain(convergenceList, gainList)
        self.assertFloatsAlmostEqual(result, nextGain)
        self.assertEqual(gainList, expectGainList)


def setup_module(module):
    """Pytest module-level hook: initialize the LSST test machinery
    before any test in this module runs.
    """
    lsst.utils.tests.init()


class MatchMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the standard lsst.utils.tests.MemoryTestCase resource-leak
    checks for this module; no module-specific additions are needed.
    """
    pass


# Allow running this test file directly (outside pytest).
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()

0 comments on commit 4653cd1

Please sign in to comment.