Skip to content

Commit

Permalink
Add removal of degnerate templates
Browse files Browse the repository at this point in the history
Compute the dot product between all pairs of templates and
remove one of the objects if the dot product is larger than
threshold.
  • Loading branch information
rearmstr committed Nov 29, 2016
1 parent e87244f commit f5ca2ab
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 28 deletions.
143 changes: 117 additions & 26 deletions python/lsst/meas/deblender/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self):
self.tinyFootprint = False
self.noValidPixels = False
self.deblendedAsPsf = False
self.degenerate = False

# Field set during _fitPsf:
self.psfFitFailed = False
Expand Down Expand Up @@ -240,6 +241,7 @@ def deblend(footprint, maskedImage, psf, psffwhm,
rampFluxAtEdge=False, patchEdges=False, tinyFootprintSize=2,
getTemplateSum=False,
clipStrayFluxFraction=0.001, clipFootprintToNonzero=True,
removeDegenerateTemplates=False, maxTempDotProd=0.5
):
"""!
Deblend a single ``footprint`` in a ``maskedImage``.
Expand Down Expand Up @@ -393,6 +395,121 @@ def deblend(footprint, maskedImage, psf, psffwhm,

pkres.setTemplate(t1, tfoot)

# Loop over fitting and identifying degenerate templates until no more objects are removed
while True:

if weightTemplates:
# Reweight the templates by doing a least-squares fit to the image
log.trace('Reweighting templates')
nchild = np.sum([pkres.skip is False for pkres in res.peaks])
A = np.zeros((W*H, nchild))
parentImage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, img, parentImage)
b = parentImage.getArray().ravel()

index = 0
for pkres in res.peaks:
if pkres.skip:
continue
childImage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, pkres.templateImage, childImage)
A[:, index] = childImage.getArray().ravel()
index += 1

X1, r1, rank1, s1 = np.linalg.lstsq(A, b)
del A
del b

index = 0
for pkres in res.peaks:
if pkres.skip:
continue
pkres.templateImage *= X1[index]
pkres.setTemplateWeight(X1[index])
index += 1

exitLoop = True

# If galaxies have substructure, such as face-on spirals, the process of identifying peaks can
# "shred" the galaxy into many pieces. The templates of shredded galaxies are typically quite
# similiar because they represent the same galaxy. We try to identify these "degenerate" peaks
# by looking at the inner product (in pixel space) of pairs of templates. If they are nearly
# parallel, we only keep one of the peaks an reject the other.
if removeDegenerateTemplates:

log.trace('Looking for degnerate templates')

nchild = np.sum([pkres.skip is False for pkres in res.peaks])
indexes = [pkres.pki for pkres in res.peaks if pkres.skip is False]

# We build a matrix that stores the dot product between templates.
# We convert the template images to HeavyFootprints because they already have a method
# to compute the dot product.
A = np.zeros((nchild, nchild))
maxTemplate = []
heavies = []
for pkres in res.peaks:
if pkres.skip:
continue
heavies.append(afwDet.makeHeavyFootprint(pkres.templateFootprint,
afwImage.MaskedImageF(pkres.templateImage)))
maxTemplate.append(np.max(pkres.templateImage.getArray()))

for i in range(nchild):
for j in range(i + 1):
A[i, j] = heavies[i].dot(heavies[j])

# Normalize the dot products to get the cosine of the angle between templates
for i in range(nchild):
for j in range(i):
norm = A[i, i]*A[j, j]
if norm <= 0:
A[i, j] = 0
else:
A[i, j] /= np.sqrt(norm)

# Iterate over pairs of objects and find the maximum non-diagonal element of the matrix.
# Exit the loop once we find a single degenerate pair greater than the threshold.
rejectedIndex = -1
foundReject = False
for i in range(nchild):
currentMax = 0.
for j in range(i):
if A[i, j] > currentMax:
currentMax = A[i, j]
if currentMax > maxTempDotProd:
foundReject = True
rejectedIndex = j

if foundReject:
break

del A

# If one of the objects is identified as a PSF keep the other one, otherwise keep the one
# with the maximum template value
if foundReject:
keep = indexes[i]
reject = indexes[rejectedIndex]
exitLoop = False
if res.peaks[keep].deblendedAsPsf and res.peaks[reject].deblendedAsPsf is False:
keep = indexes[rejectedIndex]
reject = indexes[i]
elif res.peaks[keep].deblendedAsPsf is False and res.peaks[reject].deblendedAsPsf:
reject = indexes[rejectedIndex]
keep = indexes[i]
else:
if maxTemplate[rejectedIndex] > maxTemplate[i]:
keep = indexes[rejectedIndex]
reject = indexes[i]
log.trace('Removing object with index %d : %f. Degenerate with %d' % (reject, currentMax,
keep))
res.peaks[reject].skip = True
res.peaks[reject].degenerate = True

if exitLoop:
break

# Prepare inputs to "apportionFlux" call.
# template maskedImages
tmimgs = []
Expand All @@ -418,32 +535,6 @@ def deblend(footprint, maskedImage, psf, psffwhm,
pky.append(pk.getIy())
ibi.append(pkres.pki)

if weightTemplates:
# Reweight the templates by doing a least-squares fit to the image
A = np.zeros((W*H, len(tmimgs)))
pimage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, img, pimage)
b = pimage.getArray().ravel()

index = 0
for pkres in res.peaks:
if pkres.skip:
continue
cimage = afwImage.ImageF(bb)
afwDet.copyWithinFootprintImage(fp, pkres.templateImage, cimage)
A[:, index] = cimage.getArray().ravel()
index +=1

X1, r1, rank1, s1 = np.linalg.lstsq(A, b)
del A
del b

for mim, i, w in zip(tmimgs, ibi, X1):
mim *= w
res.peaks[i].setTemplateWeight(w)

# FIXME -- Remove templates that are too similar (via dot-product test)?

# Now apportion flux according to the templates
log.trace('Apportioning flux among %i templates', len(tmimgs))
sumimg = afwImage.ImageF(bb)
Expand Down
15 changes: 13 additions & 2 deletions python/lsst/meas/deblender/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,16 @@ class SourceDeblendConfig(pexConf.Config):
"Sources violating this limit will not be deblended."),
)
weightTemplates = pexConf.Field(dtype=bool, default=False,
doc=('If true, a least-squares fit of the templates will be done'))
doc=("If true, a least-squares fit of the templates will be done to the "
"full image. The templates will be re-weighted based on this fit."))
removeDegenerateTemplates = pexConf.Field(dtype=bool, default=False,
doc=("Try to remove similar templates?"))
maxTempDotProd = pexConf.Field(dtype=float, default=0.5,
doc=("If the dot product between two templates are larger than this value"
", we consider them to be describing the same object (i.e. they are "
"degenerate). If one of the objects has been labeled as a PSF it "
"will be removed, otherwise the template with the lowest value will "
"be removed."))

## \addtogroup LSST_task_documentation
## \{
Expand Down Expand Up @@ -309,7 +318,9 @@ def deblend(self, exposure, srcs, psf):
patchEdges=(self.config.edgeHandling == 'noclip'),
tinyFootprintSize=self.config.tinyFootprintSize,
clipStrayFluxFraction=self.config.clipStrayFluxFraction,
weightTemplates=self.config.weightTemplates
weightTemplates=self.config.weightTemplates,
removeDegenerateTemplates=self.config.removeDegenerateTemplates,
maxTempDotProd=self.config.maxTempDotProd
)
if self.config.catchFailures:
src.set(self.deblendFailedKey, False)
Expand Down
119 changes: 119 additions & 0 deletions tests/testDegenerateTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#!/usr/bin/env python
#
# LSST Data Management System
#
# Copyright 2008-2016 AURA/LSST.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <https://www.lsstcorp.org/LegalNotices/>.
#
from __future__ import print_function
import unittest
import numpy as np

import lsst.utils.tests
import lsst.afw.detection as afwDet
import lsst.afw.geom as afwGeom
import lsst.afw.image as afwImage
from lsst.meas.deblender.baseline import deblend
import lsst.meas.algorithms as measAlg

def imExt(img):
bbox = img.getBBox()
return [bbox.getMinX(), bbox.getMaxX(), bbox.getMinY(), bbox.getMaxY()]

def doubleGaussianPsf(W, H, fwhm1, fwhm2, a2):
return measAlg.DoubleGaussianPsf(W, H, fwhm1, fwhm2, a2)

def gaussianPsf(W, H, fwhm):
return measAlg.DoubleGaussianPsf(W, H, fwhm)

class DegenerateTemplateTestCase(lsst.utils.tests.TestCase):

def testPeakRemoval(self):
'''
A simple example: three overlapping blobs (detected as 1
footprint with three peaks). Additional peaks are added near
the blob peaks that should be identified as degenerate.
'''
H, W = 100, 100

fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0), afwGeom.Point2I(W - 1, H - 1))

afwimg = afwImage.MaskedImageF(fpbb)
imgbb = afwimg.getBBox()
img = afwimg.getImage().getArray()

var = afwimg.getVariance().getArray()
var[:, :] = 1.

blob_fwhm = 10.
blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 2.*blob_fwhm, 0.03)

fakepsf_fwhm = 3.
fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

blobimgs = []
x = 75.
XY = [(x, 35.), (x, 65.), (50., 50.)]
flux = 1e6
for x, y in XY:
bim = blob_psf.computeImage(afwGeom.Point2D(x, y))
bbb = bim.getBBox()
bbb.clip(imgbb)

bim = bim.Factory(bim, bbb)
bim2 = bim.getArray()

blobimg = np.zeros_like(img)
blobimg[bbb.getMinY():bbb.getMaxY()+1, bbb.getMinX():bbb.getMaxX()+1] += flux*bim2
blobimgs.append(blobimg)

img[bbb.getMinY():bbb.getMaxY()+1,
bbb.getMinX():bbb.getMaxX()+1] += flux * bim2

# Run the detection code to get a ~ realistic footprint
thresh = afwDet.createThreshold(5., 'value', True)
fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
fps = fpSet.getFootprints()

self.assertTrue(len(fps) == 1)

# Add new peaks near to the first peaks that will be degenerate
fp0 = fps[0]
for x, y in XY:
fp0.addPeak(x - 10, y + 6, 10)

deb = deblend(fp0, afwimg, fakepsf, fakepsf_fwhm, verbose=True, removeDegenerateTemplates=True)

self.assertTrue(deb.peaks[3].degenerate)
self.assertTrue(deb.peaks[4].degenerate)
self.assertTrue(deb.peaks[5].degenerate)

#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-


class TestMemory(lsst.utils.tests.MemoryTestCase):
pass


def setup_module(module):
lsst.utils.tests.init()

if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()

0 comments on commit f5ca2ab

Please sign in to comment.