Skip to content

Commit

Permalink
Add removal of degnerate templates
Browse files Browse the repository at this point in the history
Compute the dot product between all pairs of templates and
remove one of the objects if the dot product is larger than
threshold.
  • Loading branch information
rearmstr committed Oct 28, 2016
1 parent 43e5640 commit 4e9d871
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 28 deletions.
143 changes: 117 additions & 26 deletions python/lsst/meas/deblender/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self):
self.tinyFootprint = False
self.noValidPixels = False
self.deblendedAsPsf = False
self.degenerate = False

# Field set during _fitPsf:
self.psfFitFailed = False
Expand Down Expand Up @@ -240,6 +241,7 @@ def deblend(footprint, maskedImage, psf, psffwhm,
rampFluxAtEdge=False, patchEdges=False, tinyFootprintSize=2,
getTemplateSum=False,
clipStrayFluxFraction=0.001, clipFootprintToNonzero=True,
removeDegenerateTemplates=False, maxTempDotProd=0.5
):
"""!
Deblend a single ``footprint`` in a ``maskedImage``.
Expand Down Expand Up @@ -393,6 +395,121 @@ def deblend(footprint, maskedImage, psf, psffwhm,

pkres.setTemplate(t1, tfoot)

# Loop over fitting and identifying degenerate templates until no more objects are removed
while True:

if weightTemplates:
# Reweight the templates by doing a least-squares fit to the image
log.trace('Reweighting templates')
nchild = np.sum([pkres.skip is False for pkres in res.peaks])
A = np.zeros((W*H, nchild))
pimage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, img, pimage)
b = pimage.getArray().ravel()

index = 0
for pkres in res.peaks:
if pkres.skip:
continue
cimage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, pkres.templateImage, cimage)
A[:, index] = cimage.getArray().ravel()
index +=1

X1, r1, rank1, s1 = np.linalg.lstsq(A, b)
del A
del b

index = 0
for pkres in res.peaks:
if pkres.skip:
continue
pkres.templateImage *= X1[index]
pkres.setTemplateWeight(X1[index])
index += 1


exitLoop = True

# If galaxies have substructure, such as face-on spirals, the process of identifying peaks can
# "shred" the galaxy into many pieces. The templates of shredded galaxies are typically quite
# similiar because they represent the same galaxy. We try to identify these "degenerate" peaks
# by looking at the inner product (in pixel space) of pairs of templates. If they are nearly
# parallel, we only keep one of the peaks an reject the other.
if removeDegenerateTemplates:

log.trace('Looking for degnerate templates')

nchild = np.sum([pkres.skip is False for pkres in res.peaks])
indexes = [pkres.pki for pkres in res.peaks if pkres.skip is False]

# We build a matrix that stores the dot product between templates.
# We convert the template images to HeavyFootprints because they already have a method
# to compute the dot product.
A = np.zeros((nchild, nchild))
maxTemplate = []
heavies = []
for pkres in res.peaks:
if pkres.skip:
continue
heavies.append(afwDet.makeHeavyFootprint(pkres.templateFootprint,
afwImage.MaskedImageF(pkres.templateImage)))
maxTemplate.append(np.max(pkres.templateImage.getArray()))

for i in range(nchild):
for j in range(i+1):
A[i, j] = heavies[i].dot(heavies[j])

# Normalize the dot products to get the cosine of the angle between templates
for i in range(nchild):
for j in range(i):
norm = A[i, i]*A[j, j]
if norm <= 0:
A[i, j] = 0
else:
A[i, j] /= np.sqrt(norm)

# Iterate over pairs of objects and find the maximum non-diagonal element of the matrix. Exit the loop once we
# find a single degenerate pair greater than the threshold.
rejectedIndex = -1
foundReject = False
for i in range(nchild):
currentMax = 0.
for j in range(i):
if A[i, j] > currentMax:
currentMax = A[i, j]
if currentMax > maxTempDotProd:
foundReject = True
rejectedIndex = j

if foundReject:
break

del A

# If one of the objects is identified as a PSF keep the other one, otherwise keep the one
# with the maximum template value
if foundReject:
keep = indexes[i]
reject = indexes[rejectedIndex]
exitLoop = False
if res.peaks[keep].deblendedAsPsf and res.peaks[reject].deblendedAsPsf is False:
keep = indexes[rejectedIndex]
reject = indexes[i]
elif res.peaks[keep].deblendedAsPsf is False and res.peaks[reject].deblendedAsPsf:
reject = indexes[rejectedIndex]
keep = indexes[i]
else:
if maxTemplate[rejectedIndex] > maxTemplate[i]:
keep = indexes[rejectedIndex]
reject = indexes[i]
log.trace('Removing object with index %d : %f. Degenerate with %d' % (reject, currentMax, keep))
res.peaks[reject].skip = True
res.peaks[reject].degenerate = True

if exitLoop:
break

# Prepare inputs to "apportionFlux" call.
# template maskedImages
tmimgs = []
Expand All @@ -418,32 +535,6 @@ def deblend(footprint, maskedImage, psf, psffwhm,
pky.append(pk.getIy())
ibi.append(pkres.pki)

if weightTemplates:
# Reweight the templates by doing a least-squares fit to the image
A = np.zeros((W*H, len(tmimgs)))
pimage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, img, pimage)
b = pimage.getArray().ravel()

index = 0
for pkres in res.peaks:
if pkres.skip:
continue
cimage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, pkres.templateImage, cimage)
A[:, index] = cimage.getArray().ravel()
index +=1

X1, r1, rank1, s1 = np.linalg.lstsq(A, b)
del A
del b

for mim, i, w in zip(tmimgs, ibi, X1):
mim *= w
res.peaks[i].setTemplateWeight(w)

# FIXME -- Remove templates that are too similar (via dot-product test)?

# Now apportion flux according to the templates
log.trace('Apportioning flux among %i templates', len(tmimgs))
sumimg = afwImage.ImageF(bb)
Expand Down
12 changes: 10 additions & 2 deletions python/lsst/meas/deblender/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ class SourceDeblendConfig(pexConf.Config):
"Sources violating this limit will not be deblended."),
)
weightTemplates = pexConf.Field(dtype=bool, default=False,
doc=('If true, a least-squares fit of the templates will be done'))
doc=("If true, a least-squares fit of the templates will be done"))
removeDegenerateTemplates = pexConf.Field(dtype=bool, default=False,
doc=("Try to remove similar tempates?"))
maxTempDotProd = pexConf.Field(dtype=float, default=0.5,
doc=("Threshold on the dot product between two templates that "
"determines whether they are degenerate. Pairs with values "
"larger than this will remove one of the templates."))

## \addtogroup LSST_task_documentation
## \{
Expand Down Expand Up @@ -309,7 +315,9 @@ def deblend(self, exposure, srcs, psf):
patchEdges=(self.config.edgeHandling == 'noclip'),
tinyFootprintSize=self.config.tinyFootprintSize,
clipStrayFluxFraction=self.config.clipStrayFluxFraction,
weightTemplates=self.config.weightTemplates
weightTemplates=self.config.weightTemplates,
removeDegenerateTemplates=self.config.removeDegenerateTemplates,
maxTempDotProd=self.config.maxTempDotProd
)
if self.config.catchFailures:
src.set(self.deblendFailedKey, False)
Expand Down
145 changes: 145 additions & 0 deletions tests/testDegenerateTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env python
#
# LSST Data Management System
#
# Copyright 2008-2016 AURA/LSST.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <https://www.lsstcorp.org/LegalNotices/>.
#
from __future__ import print_function
import unittest
import numpy as np

import lsst.utils.tests
import lsst.afw.detection as afwDet
import lsst.afw.geom as afwGeom
import lsst.afw.image as afwImage
import lsst.pex.logging as pexLogging
from lsst.meas.deblender.baseline import deblend
import lsst.meas.algorithms as measAlg

root = pexLogging.Log.getDefaultLog()
root.setThreshold(pexLogging.Log.DEBUG)

# Quiet some of the more chatty loggers
pexLogging.Log(root, 'lsst.meas.deblender.symmetrizeFootprint',
pexLogging.Log.INFO)
pexLogging.Log(root, 'lsst.meas.deblender.symmetricFootprint',
pexLogging.Log.INFO)
pexLogging.Log(root, 'lsst.meas.deblender.getSignificantEdgePixels',
pexLogging.Log.INFO)
pexLogging.Log(root, 'afw.Mask',
pexLogging.Log.INFO)


def imExt(img):
bbox = img.getBBox()
return [bbox.getMinX(), bbox.getMaxX(),
bbox.getMinY(), bbox.getMaxY()]


def doubleGaussianPsf(W, H, fwhm1, fwhm2, a2):
return measAlg.DoubleGaussianPsf(W, H, fwhm1, fwhm2, a2)


def gaussianPsf(W, H, fwhm):
return measAlg.DoubleGaussianPsf(W, H, fwhm)


class DegenerateTemplateTestCase(lsst.utils.tests.TestCase):

def test1(self):
'''
A simple example: three overlapping blobs (detected as 1
footprint with three peaks). Additional peaks are added near
the blob peaks that should be identified as degenerate.
'''
H, W = 100, 100

fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0),
afwGeom.Point2I(W-1, H-1))

afwimg = afwImage.MaskedImageF(fpbb)
imgbb = afwimg.getBBox()
img = afwimg.getImage().getArray()

var = afwimg.getVariance().getArray()
var[:, :] = 1.

blob_fwhm = 10.
blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 2.*blob_fwhm, 0.03)

fakepsf_fwhm = 3.
fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

blobimgs = []
x = 75.
XY = [(x, 35.), (x, 65.), (50., 50.)]
flux = 1e6
for x, y in XY:
bim = blob_psf.computeImage(afwGeom.Point2D(x, y))
bbb = bim.getBBox()
bbb.clip(imgbb)

bim = bim.Factory(bim, bbb)
bim2 = bim.getArray()

blobimg = np.zeros_like(img)
blobimg[bbb.getMinY():bbb.getMaxY()+1,
bbb.getMinX():bbb.getMaxX()+1] += flux * bim2
blobimgs.append(blobimg)

img[bbb.getMinY():bbb.getMaxY()+1,
bbb.getMinX():bbb.getMaxX()+1] += flux * bim2

# Run the detection code to get a ~ realistic footprint
thresh = afwDet.createThreshold(5., 'value', True)
fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
fps = fpSet.getFootprints()
print('found', len(fps), 'footprints')

# Add new peaks near to the first peaks that will be degenerate
fp0 = fps[0]
for x, y in XY:
fp0.addPeak(x-10, y+6, 10)

for fp in fps:
print('peaks:', len(fp.getPeaks()))
for pk in fp.getPeaks():
print(' ', pk.getIx(), pk.getIy())


deb = deblend(fp0, afwimg, fakepsf, fakepsf_fwhm, verbose=True, removeDegenerateTemplates=True)

self.assertTrue(deb.peaks[3].degenerate)
self.assertTrue(deb.peaks[4].degenerate)
self.assertTrue(deb.peaks[5].degenerate)

#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-


class TestMemory(lsst.utils.tests.MemoryTestCase):
pass


def setup_module(module):
lsst.utils.tests.init()

if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()

0 comments on commit 4e9d871

Please sign in to comment.