Skip to content

Commit

Permalink
Merge pull request #9 from lsst/tickets/DM-20129
Browse files Browse the repository at this point in the history
tickets/DM-20129
  • Loading branch information
fred3m committed Jul 25, 2019
2 parents 2ad8b30 + 6b94a32 commit 7ac3d52
Show file tree
Hide file tree
Showing 9 changed files with 469 additions and 147 deletions.
18 changes: 8 additions & 10 deletions python/lsst/meas/extensions/scarlet/blend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


from scarlet.blend import Blend


Expand All @@ -11,10 +9,10 @@ class LsstBlend(Blend):
for multiresolution blends. So this class exists for any
LSST specific changes.
"""
def get_model(self, *parameters, observation=None):
model = super().get_model(*parameters)
def get_model(self, seds=None, morphs=None, observation=None):
model = super().get_model(seds, morphs)
if observation is not None:
model = observation.get_model(model)
model = observation.render(model)
return model

def display_model(self, observation=None, ax=None, filters=None, Q=10, stretch=1, show=True):
Expand All @@ -27,10 +25,10 @@ def display_model(self, observation=None, ax=None, filters=None, Q=10, stretch=1
ax = fig.add_subplot(1, 1, 1)
if filters is None:
filters = [2, 1, 0]
img_rgb = make_lupton_rgb(image_r=model[filters[0]], # numpy array for the r channel
image_g=model[filters[1]], # numpy array for the g channel
image_b=model[filters[2]], # numpy array for the b channel
stretch=stretch, Q=Q) # parameters used to stretch and scale the values
ax.imshow(img_rgb, interpolation='nearest')
imgRgb = make_lupton_rgb(image_r=model[filters[0]], # numpy array for the r channel
image_g=model[filters[1]], # numpy array for the g channel
image_b=model[filters[2]], # numpy array for the b channel
stretch=stretch, Q=Q) # parameters used to stretch and scale the values
ax.imshow(imgRgb, interpolation='nearest')
if show:
plt.show()
26 changes: 13 additions & 13 deletions python/lsst/meas/extensions/scarlet/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .source import LsstSource, LsstHistory
from .blend import LsstBlend
from .observation import LsstScene, LsstObservation
from .observation import LsstFrame, LsstObservation


logger = lsst.log.Log.getLogger("meas.deblender.deblend")
Expand All @@ -27,17 +27,17 @@ def _getPsfFwhm(psf):
return psf.computeShape().getDeterminantRadius() * 2.35


def _estimateStdDev(exposure, statsMask):
def _estimateRMS(exposure, statsMask):
"""Estimate the standard dev. of an image
Take the median standard deviation of the `exposure`.
Calculate the RMS of the `exposure`.
"""
mi = exposure.getMaskedImage()
statsCtrl = afwMath.StatisticsControl()
statsCtrl.setAndMask(mi.getMask().getPlaneBitMask(statsMask))
stats = afwMath.makeStatistics(mi.variance, mi.mask, afwMath.MEDIAN, statsCtrl)
sigma = np.sqrt(stats.getValue(afwMath.MEDIAN))
return sigma
stats = afwMath.makeStatistics(mi.variance, mi.mask, afwMath.STDEV | afwMath.MEAN, statsCtrl)
rms = np.sqrt(stats.getValue(afwMath.MEAN)**2 + stats.getValue(afwMath.STDEV)**2)
return rms


def _getTargetPsf(shape, sigma=1/np.sqrt(2)):
Expand Down Expand Up @@ -76,21 +76,21 @@ def deblend(mExposure, footprint, log, config):
psfs = mExposure.computePsfImage(footprint.getCentroid()).array
target_psf = _getTargetPsf(psfs.shape)

observation = LsstObservation(images, psfs, weights)
scene = LsstScene(images.shape, psfs=target_psf)
bg_rms = np.array([_estimateStdDev(exposure, config.statsMask) for exposure in mExposure[:, bbox]])
frame = LsstFrame(images.shape, psfs=target_psf[None])
observation = LsstObservation(images, psfs, weights).match(frame)
bgRms = np.array([_estimateRMS(exposure, config.statsMask) for exposure in mExposure[:, bbox]])
if config.storeHistory:
Source = LsstHistory
else:
Source = LsstSource
sources = [
Source(peak=center, scene=scene, observations=observation, bg_rms=bg_rms,
Source(frame=frame, peak=center, observation=observation, bgRms=bgRms,
bbox=bbox, symmetric=config.symmetric, monotonic=config.monotonic,
center_step=config.recenterPeriod)
centerStep=config.recenterPeriod)
for center in footprint.peaks
]

blend = LsstBlend(scene, sources, observation)
blend = LsstBlend(sources, observation)
blend.fit(config.maxIter, config.relativeError, False)

return blend
Expand Down Expand Up @@ -124,7 +124,7 @@ class ScarletDeblendConfig(pexConfig.Config):
# Constraints
sparse = pexConfig.Field(dtype=bool, default=True, doc="Make models compact and sparse")
monotonic = pexConfig.Field(dtype=bool, default=True, doc="Make models monotonic")
symmetric = pexConfig.Field(dtype=bool, default=False, doc="Make models symmetric")
symmetric = pexConfig.Field(dtype=bool, default=True, doc="Make models symmetric")
symmetryThresh = pexConfig.Field(dtype=float, default=1.0,
doc=("Strictness of symmetry, from"
"0 (no symmetry enforced) to"
Expand Down
5 changes: 3 additions & 2 deletions python/lsst/meas/extensions/scarlet/observation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from scarlet.observation import Scene, Observation

from scarlet.observation import Frame, Observation

class LsstScene(Scene):

class LsstFrame(Frame):
pass


Expand Down
45 changes: 25 additions & 20 deletions python/lsst/meas/extensions/scarlet/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,31 @@ class LsstSource(ExtendedSource):
default initialization and update constraints for general sources in
LSST images.
"""
def __init__(self, peak, scene, observations, bg_rms, bbox, obs_idx=0,
thresh=1, symmetric=False, monotonic=True, center_step=5,
**component_kwargs):
def __init__(self, frame, peak, observation, bgRms, bbox,
thresh=1, symmetric=True, monotonic=True, centerStep=5,
pointSource=False, **componentKwargs):
xmin = bbox.getMinX()
ymin = bbox.getMinY()
sky_coord = np.array([peak.getIy()-ymin, peak.getIx()-xmin], dtype=int)
try:
super().__init__(sky_coord, scene, observations, bg_rms, obs_idx, thresh,
symmetric, monotonic, center_step, **component_kwargs)
except SourceInitError:
# If the source is too faint for background detection, initialize
# it as a PointSource
PointSource.__init__(self, sky_coord, scene, observations, symmetric, monotonic,
center_step, **component_kwargs)
center = np.array([peak.getIy()-ymin, peak.getIx()-xmin], dtype=int)
initialized = False
if not pointSource:
try:
super().__init__(frame, center, observation, bgRms, thresh,
symmetric, monotonic, centerStep, **componentKwargs)
initialized = True
except SourceInitError:
# If the source is too faint for background detection,
# initialize it as a PointSource
pass
if not initialized:
PointSource.__init__(self, frame, center, observation, symmetric, monotonic,
centerStep, **componentKwargs)
self.detectedPeak = peak

def get_model(self, *parameters, observation=None):
model = super().get_model(*parameters)
def get_model(self, sed=None, morph=None, observation=None):
model = super().get_model(sed, morph)
if observation is not None:
model = observation.get_model(model)
model = observation.render(model)
return model

def display_model(self, observation=None, ax=None, filters=None, Q=10, stretch=1, show=True):
Expand All @@ -47,11 +52,11 @@ def display_model(self, observation=None, ax=None, filters=None, Q=10, stretch=1
ax = fig.add_subplot(1, 1, 1)
if filters is None:
filters = [2, 1, 0]
img_rgb = make_lupton_rgb(image_r=model[filters[0]], # numpy array for the r channel
image_g=model[filters[1]], # numpy array for the g channel
image_b=model[filters[2]], # numpy array for the b channel
stretch=stretch, Q=Q) # parameters used to stretch and scale the values
ax.imshow(img_rgb, interpolation='nearest')
imgRgb = make_lupton_rgb(image_r=model[filters[0]], # numpy array for the r channel
image_g=model[filters[1]], # numpy array for the g channel
image_b=model[filters[2]], # numpy array for the b channel
stretch=stretch, Q=Q) # parameters used to stretch and scale the values
ax.imshow(imgRgb, interpolation='nearest')
if show:
plt.show()

Expand Down
119 changes: 119 additions & 0 deletions tests/test_deblend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@


# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest

import numpy as np

import lsst.utils.tests
import lsst.afw.image as afwImage
from lsst.meas.algorithms import SourceDetectionTask
from lsst.meas.extensions.scarlet import ScarletDeblendTask
from lsst.afw.table import SourceCatalog
from lsst.afw.detection import MultibandFootprint
from lsst.afw.image import Image, MultibandImage

from utils import initData


class TestDeblend(lsst.utils.tests.TestCase):
def test_deblend_task(self):
# Set the random seed so that the noise field is unaffected
np.random.seed(0)
# Test that executing the deblend task works
# In the future we can have more detailed tests,
# but for now this at least ensures that the task isn't broken
shape = (5, 31, 55)
coords = [(15, 25), (10, 30), (17, 38)]
amplitudes = [80, 60, 90]
result = initData(shape, coords, amplitudes)
targetPsfImage, psfImages, images, channels, seds, morphs, targetPsf, psfs = result
B, Ny, Nx = shape

# Add some noise, otherwise the task will blow up due to
# zero variance
noise = 10*(np.random.rand(*images.shape)-.5)
images += noise

filters = "grizy"
_images = afwImage.MultibandMaskedImage.fromArrays(filters, images.astype(np.float32), None,
noise.astype(np.float32))
coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
coadds = afwImage.MultibandExposure.fromExposures(filters, coadds)
for b, coadd in enumerate(coadds):
coadd.setPsf(psfs[b])

schema = SourceCatalog.Table.makeMinimalSchema()

detectionTask = SourceDetectionTask(schema=schema)
config = ScarletDeblendTask.ConfigClass()
config.maxIter = 200
deblendTask = ScarletDeblendTask(schema=schema, config=config)

table = SourceCatalog.Table.make(schema)
detectionResult = detectionTask.run(table, coadds["r"])
catalog = detectionResult.sources
self.assertEqual(len(catalog), 1)
_, result = deblendTask.run(coadds, catalog)

# Changes to the internal workings of scarlet will change these results
# however we include these tests just to track changes
parent = result["r"][0]
self.assertEqual(parent["iterations"], 11)
self.assertEqual(parent["deblend_nChild"], 3)

heavies = []
for k in range(1, len(result["g"])):
heavy = MultibandFootprint(coadds.filters, [result[b][k].getFootprint() for b in filters])
heavies.append(heavy)

seds = np.array([heavy.getImage(fill=0).image.array.sum(axis=(1, 2)) for heavy in heavies])
true_seds = np.array([
[[1665.726318359375, 1745.5401611328125, 1525.91796875, 997.3868408203125, 0.0],
[767.100341796875, 1057.0374755859375, 1312.89111328125, 1694.7535400390625, 2069.294921875],
[8.08012580871582, 879.344970703125, 2246.90087890625, 4212.82470703125, 6987.0849609375]]
])

self.assertFloatsAlmostEqual(true_seds, seds, rtol=1e-8, atol=1e-8)

bbox = parent.getFootprint().getBBox()
data = coadds[:, bbox]
model = MultibandImage.fromImages(coadds.filters, [
Image(bbox, dtype=np.float32)
for b in range(len(filters))
])
for heavy in heavies:
model[:, heavy.getBBox()].array += heavy.getImage(fill=0).image.array

residual = data.image.array - model.array
self.assertFloatsAlmostEqual(np.abs(residual).sum(), 11601.3867187500)
self.assertFloatsAlmostEqual(np.max(np.abs(residual)), 56.1048278809, rtol=1e-8, atol=1e-8)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()

0 comments on commit 7ac3d52

Please sign in to comment.