diff --git a/doc/versionHistory.rst b/doc/versionHistory.rst index 7ba087a1..1fb38eda 100644 --- a/doc/versionHistory.rst +++ b/doc/versionHistory.rst @@ -6,6 +6,15 @@ Version History ################## +.. _lsst.ts.wep-9.5.5: + +------------- +9.5.5 +------------- + +* Correct indices used to calculate Zernike average. +* Update tests to discern whether flags and mean use the same indices. + .. _lsst.ts.wep-9.5.4: ------------- diff --git a/python/lsst/ts/wep/task/combineZernikesSigmaClipTask.py b/python/lsst/ts/wep/task/combineZernikesSigmaClipTask.py index c3d857f6..8a034fe3 100644 --- a/python/lsst/ts/wep/task/combineZernikesSigmaClipTask.py +++ b/python/lsst/ts/wep/task/combineZernikesSigmaClipTask.py @@ -77,10 +77,29 @@ def combineZernikes(self, zernikeArray): # Create a binary flag array that indicates # donuts have outlier values. This array is 1 if # it has any outlier values. - binaryFlagArray = np.any( - np.isnan(sigArray[:, : self.maxZernClip]), axis=1 - ).astype(int) + # If all available donuts have a clipped value in the + # first maxZernClip coefficients then reduce maxZernClip by 1 + # until we get one that passes. + numRejected = len(sigArray) + effMaxZernClip = self.maxZernClip + 1 + + while numRejected == len(sigArray): + effMaxZernClip -= 1 + binaryFlagArray = np.any( + np.isnan(sigArray[:, :effMaxZernClip]), axis=1 + ).astype(int) + numRejected = np.sum(binaryFlagArray) # Identify which rows to use when calculating final mean - keepIdx = ~np.any(np.isnan(sigArray), axis=1) + keepIdx = ~np.array(binaryFlagArray, dtype=bool) + + self.log.info( + f"MaxZernClip config: {self.maxZernClip}. MaxZernClip used: {effMaxZernClip}." + ) + if effMaxZernClip < self.maxZernClip: + self.log.warning( + f"EffMaxZernClip ({effMaxZernClip}) was less than MaxZernClip config ({self.maxZernClip})." + ) + self.metadata["maxZernClip"] = self.maxZernClip + self.metadata["effMaxZernClip"] = effMaxZernClip return np.mean(zernikeArray[keepIdx], axis=0), binaryFlagArray diff --git a/tests/task/test_combineZernikesSigmaClipTask.py b/tests/task/test_combineZernikesSigmaClipTask.py index 8e2ee18f..29777620 100644 --- a/tests/task/test_combineZernikesSigmaClipTask.py +++ b/tests/task/test_combineZernikesSigmaClipTask.py @@ -71,7 +71,7 @@ def testCombineZernikes(self): # Test that zernikes higher than maxZernClip don't remove # a row from the final averaging zernikeArray[0, 3:] += 100.0 - zernikeArray[51, 3:] -= 100.0 + zernikeArray[49, 3:] -= 100.0 # Revert the change in the 100th row from previous test trueFlags[100] = 1 combinedZernikes, testFlags = self.task.combineZernikes(zernikeArray) @@ -80,13 +80,33 @@ def testCombineZernikes(self): self.assertTrue(isinstance(testFlags[0], numbers.Integral)) # Test that changing the maxZernClip parameter does change - # if a row is removed from the final result + # whether a row is removed from the final result + zernikeArray[50, 3:] += 100.0 + zernikeArray[51, 3:] -= 100.0 self.config.maxZernClip = 5 self.task = CombineZernikesSigmaClipTask(config=self.config) combinedZernikes, testFlags = self.task.combineZernikes(zernikeArray) np.testing.assert_array_equal(np.ones(10) * 2.0, combinedZernikes) trueFlags[0] = 1 - trueFlags[51] = 1 + trueFlags[49:52] = 1 + np.testing.assert_array_equal(trueFlags, testFlags) + self.assertTrue(isinstance(testFlags[0], numbers.Integral)) + + def testCombineZernikesEffectiveMaxZernClip(self): + testWhileZernikeArray = np.ones((3, 10)) + testWhileZernikeArray[0, 4] = 3 + testWhileZernikeArray[1, 3] = 3 + testWhileZernikeArray[2, 2] = 3 + + # Test that changing the maxZernClip parameter does change + # whether a row is removed from the final result + self.config.maxZernClip = 6 + self.task = CombineZernikesSigmaClipTask(config=self.config) + combinedZernikes, testFlags = self.task.combineZernikes(testWhileZernikeArray) + np.testing.assert_array_equal(testWhileZernikeArray[0], combinedZernikes) + self.assertEqual(self.task.metadata["maxZernClip"], 6) + self.assertEqual(self.task.metadata["effMaxZernClip"], 4) + trueFlags = np.array([0, 1, 1]) np.testing.assert_array_equal(trueFlags, testFlags) self.assertTrue(isinstance(testFlags[0], numbers.Integral))