Skip to content

Commit

Permalink
Add unit test and fix bugs that illustrated.
Browse files Browse the repository at this point in the history
- Update docstrings and comments to be clearer.
- Fix faulty valid/badAmp logic.
- Add __eq__ handler.
  • Loading branch information
czwa committed Mar 31, 2021
1 parent 0747caf commit f2b5c2e
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 24 deletions.
45 changes: 38 additions & 7 deletions python/lsst/ip/isr/brighterFatterKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class BrighterFatterKernel(IsrCalib):
makeDetectorKernelFromAmpwiseKernels is a method to generate the
kernel for a detector, constructed by averaging together the
ampwise kernels in the detector.
ampwise kernels in the detector. The existing application code is
only defined for kernels with level == 'DETECTOR', so this method
is used if the supplied kernel was built with level == 'AMP'.
Parameters
----------
Expand All @@ -66,7 +68,7 @@ def __init__(self, camera=None, level=None, **kwargs):
self.variances = dict()
self.rawXcorrs = dict()
self.badAmps = list()
self.shape = (8, 8)
self.shape = (17, 17)
self.gain = dict()
self.noise = dict()

Expand All @@ -86,6 +88,24 @@ def __init__(self, camera=None, level=None, **kwargs):
'badAmps', 'gain', 'noise', 'meanXcorrs', 'valid',
'ampKernels', 'detKernels'])

def __eq__(self, other):
"""Calibration equivalence
"""
if not isinstance(other, self.__class__):
return False

for attr in self._requiredAttributes:
attrSelf = getattr(self, attr)
attrOther = getattr(other, attr)
if isinstance(attrSelf, dict) and isinstance(attrOther, dict):
for ampName in attrSelf:
if not np.allclose(attrSelf[ampName], attrOther[ampName], equal_nan=True):
return False
else:
if attrSelf != attrOther:
return False
return True

def updateMetadata(self, setDate=False, **kwargs):
"""Update calibration metadata.
Expand Down Expand Up @@ -120,10 +140,19 @@ def initFromCamera(self, camera, detectorId=None):
-------
calib : `lsst.ip.isr.BrighterFatterKernel`
The initialized calibration.
Raises
------
RuntimeError :
Raised if no detectorId is supplied for a calibration with
level='AMP'.
"""
self._instrument = camera.getName()

if self.level == 'AMP':
if detectorId is None:
raise RuntimeError("A detectorId must be supplied if level='AMP'.")

detector = camera[detectorId]
self._detectorId = detectorId
self._detectorName = detector.getName()
Expand Down Expand Up @@ -200,18 +229,19 @@ def fromDict(cls, dictionary):
calib.level = dictionary['metadata'].get('LEVEL', 'AMP')
calib.shape = (dictionary['metadata'].get('KERNEL_DX', 0),
dictionary['metadata'].get('KERNEL_DY', 0))
calib.badAmps = dictionary['badAmps']

calib.means = {amp: np.array(dictionary['means'][amp]) for amp in dictionary['means']}
calib.variances = {amp: np.array(dictionary['variances'][amp]) for amp in dictionary['variances']}

# Lengths for reshape:
_, smallShape, nObs = calib.getLengths()
_, smallLength, nObs = calib.getLengths()
smallShapeSide = int(np.sqrt(smallLength))

calib.rawXcorrs = {amp: np.array(dictionary['rawXcorrs'][amp]).reshape((nObs,
smallShape[0],
smallShape[1]))
smallShapeSide,
smallShapeSide))
for amp in dictionary['rawXcorrs']}

calib.gain = dictionary['gain']
calib.noise = dictionary['noise']

Expand All @@ -220,6 +250,7 @@ def fromDict(cls, dictionary):
calib.ampKernels = {amp: np.array(dictionary['ampKernels'][amp]).reshape(calib.shape)
for amp in dictionary['ampKernels']}
calib.valid = {amp: bool(value) for amp, value in dictionary['valid'].items()}
calib.badAmps = [amp for amp, valid in dictionary['valid'].items() if valid is False]

calib.detKernels = {det: np.array(dictionary['detKernels'][det]).reshape(calib.shape)
for det in dictionary['detKernels']}
Expand Down Expand Up @@ -311,7 +342,7 @@ def fromTable(cls, tableList):
inDict['ampKernels'] = {amp: kernel for amp, kernel in zip(amps, ampKernels)}
inDict['valid'] = {amp: bool(valid) for amp, valid in zip(amps, validList)}

inDict['badAmps'] = [amp for amp, valid in inDict['valid'].items() if valid is True]
inDict['badAmps'] = [amp for amp, valid in inDict['valid'].items() if valid is False]

if len(tableList) > 1:
detTable = tableList[1]
Expand Down
63 changes: 46 additions & 17 deletions tests/test_brighterFatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,49 @@
#

import unittest
import pickle
import os
import numpy as np

import lsst.utils.tests
import lsst.afw.cameraGeom as cameraGeom
import lsst.afw.image as afwImage
import lsst.ip.isr.isrFunctions as isrFunctions
from lsst.ip.isr import BrighterFatterKernel


class BrighterFatterTestCases(lsst.utils.tests.TestCase):

def setUp(self):
self.filename = "bf_kernel.pkl"
kernel = afwImage.ImageF(17, 17)
kernel[9, 9, afwImage.LOCAL] = 1
kernelPickleString = kernel.getArray().dumps()
# kernel.getArray().dump(self.filename) triggers an "unclosed file" warning with numpy 1.13.1
with open(self.filename, 'wb') as f:
f.write(kernelPickleString)

def tearDown(self):
os.unlink(self.filename)

def testBrighterFatterInterface(self):
"""Set up a no-op BFK dataset
"""
cameraBuilder = cameraGeom.Camera.Builder('fake camera')
detectorWrapper = cameraGeom.testUtils.DetectorWrapper(numAmps=4, cameraBuilder=cameraBuilder)
self.detector = detectorWrapper.detector
camera = cameraBuilder.finish()

self.bfk = BrighterFatterKernel(level='AMP', camera=camera, detectorId=1)
self.bfk.shape = (17, 17)
self.bfk.badAmps = ['amp 3']

covar = np.zeros((8, 8))
covar[0, 0] = 1.0

kernel = np.zeros(self.bfk.shape)
kernel[8, 8] = 1.0

for amp in self.detector:
ampName = amp.getName()
self.bfk.means[ampName] = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
self.bfk.variances[ampName] = np.array(self.bfk.means[ampName], dtype=float)
self.bfk.rawXcorrs[ampName] = [covar for _ in self.bfk.means[ampName]]
self.bfk.gain[ampName] = 1.0
self.bfk.noise[ampName] = 5.0

self.bfk.meanXcorrs[ampName] = kernel
self.bfk.valid[ampName] = (ampName != 'amp 3')

self.bfk.ampKernels[ampName] = kernel

def test_BrighterFatterInterface(self):
"""Test brighter fatter correction interface using a delta function kernel on a flat image"""

image = afwImage.ImageF(100, 100)
Expand All @@ -53,12 +73,21 @@ def testBrighterFatterInterface(self):
mi = afwImage.makeMaskedImage(image)
exp = afwImage.makeExposure(mi)

with open(self.filename, 'rb') as f:
bfKernel = pickle.load(f)
self.bfk.makeDetectorKernelFromAmpwiseKernels(self.detector.getName())
kernelToUse = self.bfk.detKernels[self.detector.getName()]

isrFunctions.brighterFatterCorrection(exp, bfKernel, 5, 100, False)
isrFunctions.brighterFatterCorrection(exp, kernelToUse, 5, 100, False)
self.assertImagesEqual(ref_image, image)

def test_BrighterFatterIO(self):
dictionary = self.bfk.toDict()
newBfk = BrighterFatterKernel().fromDict(dictionary)
self.assertEqual(self.bfk, newBfk)

tables = self.bfk.toTable()
newBfk = BrighterFatterKernel().fromTable(tables)
self.assertEqual(self.bfk, newBfk)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit f2b5c2e

Please sign in to comment.