# Deblending with *Scarlet*
<br>Owner(s): **Fred Moolekamp** ([@fred3m](https://github.com/LSSTScienceCollaborations/StackClub/issues/new?body=@fred3m))
<br>Last Verified to Run: **2020-07-10**
<br>Verified Stack Release: **v20.0.0**

The purpose of this tutorial is to familiarize you with the basics of using *scarlet* to model blended scenes, and how tweaking various objects and parameters affects the resulting model. A tutorial that is more specific to using scarlet in the context of the LSST DM Science Pipelines is also available.

### Learning Objectives:

After working through this tutorial you should be able to: 
1. Configure and run _scarlet_ on a test list of objects;
2. Understand its various model assumptions and applied constraints;
3. Use specific configurations to fit objects of different natures (stars, galaxies, low-surface-brightness galaxies);
4. Bonus: get a sense of how to use your own assumptions to build models in scarlet.

Before attempting this tutorial it will be useful to read the [introduction](https://fred3m.github.io/scarlet/user_docs.html) to the *scarlet* User Guide, and many of the exercises below may require referencing the *scarlet* [docs](https://fred3m.github.io/scarlet/).

### Logistics
This notebook is intended to be runnable on `cori.nersc.gov` from a local git clone of https://github.com/LSSTDESC/StackClub.

## Set-up

In [None]:
# What version of the Stack are we using?
! echo $HOSTNAME
! eups list -s | grep lsst_distrib

In [None]:
# Import the necessary libraries
import os
import pickle

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# don't interpolate the pixels
matplotlib.rc('image', interpolation='none')

import numpy as np
from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping

import scarlet
import scarlet.display
from scarlet import Starlet

We will also load the butler and various lsst packages

In [None]:
import lsst.geom
from lsst.daf.persistence import Butler
from lsst.geom import Box2I, Box2D, Point2I, Point2D, Extent2I, Extent2D
from lsst.afw.image import Exposure, Image, PARENT, MultibandExposure, MultibandImage
from lsst.afw.detection import MultibandFootprint

# Load and Display the data

More information is provided in the `lsst_stack_deblender.ipynb` tutorial. 

The **butler** is used to retrieve data from DESC DC2 DR6 by specifying the tract, patch and filters.

In [None]:
import desc_dc2_dm_data
butler = desc_dc2_dm_data.get_butler("2.2i_dr6_wfd")
dataId = {"tract": 3830, "patch": "4,4"}
filters = "ugrizy"
coadds = [butler.get("deepCoadd_calexp", dataId, filter=f) for f in filters]
coadds = MultibandExposure.fromExposures(filters, coadds)

We then display the patch using scarlet's built-in display functions. The `norm` defines an asinh colour scaling that avoids saturating the centres of bright objects to white, a technique also known as _Luptonisation_. 

The `img_to_rgb` function maps multi-band arrays into 3-channel RGB images. The mapping is done automatically by default, assuming an ordering of bands from bluer to redder, but it can be customised using the `channel_map` argument.
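To make the idea of a channel map concrete, here is a minimal numpy sketch, assuming a channel map is a `(3, n_bands)` weight matrix that collapses the band axis into red, green and blue planes. The weights below (a simple linear spread from blue to red) are illustrative only and are not scarlet's default mapping; `make_channel_map` and `apply_channel_map` are hypothetical helper names.

```python
import numpy as np

def make_channel_map(n_bands):
    """Hypothetical helper: spread n_bands (ordered blue -> red) over R, G, B.

    Returns a (3, n_bands) array; row 0 is red, row 1 green, row 2 blue.
    """
    if n_bands < 3:
        raise ValueError("this sketch needs at least 3 bands")
    # Position of each band on a 0 (bluest) .. 1 (reddest) scale
    pos = np.linspace(0.0, 1.0, n_bands)
    red = pos                              # reddest bands weigh most in red
    green = 1.0 - np.abs(2.0 * pos - 1.0)  # middle bands weigh most in green
    blue = 1.0 - pos                       # bluest bands weigh most in blue
    channel_map = np.vstack([red, green, blue])
    # Normalise each colour channel so that its weights sum to 1
    return channel_map / channel_map.sum(axis=1, keepdims=True)

def apply_channel_map(cube, channel_map):
    """Collapse a (n_bands, ny, nx) cube into a (3, ny, nx) RGB cube."""
    return np.tensordot(channel_map, cube, axes=(1, 0))

# Example: six bands (like ugrizy) mapped onto three colour channels
channel_map = make_channel_map(6)
rgb = apply_channel_map(np.ones((6, 4, 5)), channel_map)
```

A matrix of this kind is the sort of object one would pass as `channel_map`; check the `scarlet.display` documentation for the exact orientation and normalisation scarlet expects.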

In [None]:
norm = scarlet.display.AsinhMapping(minimum=0, stretch=1, Q=10)
rgb_patch = scarlet.display.img_to_rgb(coadds.image.array, norm=norm)

In [None]:
plt.figure(figsize = (30,30))
plt.imshow(rgb_patch, origin = "lower")
plt.xticks(fontsize = 30)
plt.yticks(fontsize = 30)
plt.show()

Let's formalise the previous display procedure into a function

In [None]:
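def imshow_rgb(image, norm=None, figsize=None, cat=None):
    """Display a multiband image as an RGB image.

    Parameters
    ----------
    image: `numpy.ndarray`
        Multiband image to display, with shape (bands, ny, nx)
    norm: `scarlet.display.AsinhMapping`, optional
        Colour scaling; defaults to an asinh stretch with minimum=0, stretch=1, Q=10
    figsize: tuple, optional
        Size of the matplotlib figure
    cat: catalog with 'x' and 'y' fields, optional
        Source positions to overplot as white crosses
    """
    if norm is None:
        norm = scarlet.display.AsinhMapping(minimum=0, stretch=1, Q=10)
    rgb_patch = scarlet.display.img_to_rgb(image, norm=norm)

    plt.figure(figsize=figsize)
    plt.imshow(rgb_patch, origin="lower")
    plt.xticks(fontsize=30)
    plt.yticks(fontsize=30)
    # Display source positions
    if cat is not None:
        plt.plot(cat['x'], cat['y'], 'wx', markersize=15)
    plt.show()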

## Display a subset

Here we extract a 180×180 pixel cutout from the patch to run scarlet on.

In [None]:
n = 180
sampleBBox = Box2I(Point2I(16880, 19320), Extent2I(n, n))

subset = coadds[:, sampleBBox]
# Due to a bug in the code the PSF isn't copied properly.
# The code below copies the PSF into the `MultibandExposure`,
# but will be unnecessary in the future
for f in subset.filters:
    subset[f].setPsf(coadds[f].getPsf())

In [None]:
images = subset.image.array
imshow_rgb(images, figsize = (15,15))

# Display the PSFs

In [None]:
psfs = subset.computePsfImage(Point2I(16880, 19320)).array
psf_norm = scarlet.display.AsinhMapping(minimum=0, stretch=0.001, Q=10)
imshow_rgb(psfs, norm = psf_norm, figsize = (15, 15))

# Scarlet background

This short introduction to scarlet concepts is adapted from a notebook built by Fred Moolekamp and Peter Melchior: https://github.com/pmelchior/scarlet/blob/master/docs/1-concepts.ipynb.
Scarlet models the individual sources $S_i$ in an image $Y$, convolved with the instrument PSF $H$, in the presence of additive noise $N$. The model for images is therefore: $$Y = H\sum_i S_i + N$$

In scarlet, an image can be a cube of images observed at different wavelengths. Scarlet exploits the colour of each source to reconstruct individual models by factorising a source into a spectrum and a morphology such that, for a given source: $$S_i = A_iM_i$$ Here, $S_i$ is an array with the shape of the observations ($n_{channels}\times n_{ypixels} \times n_{xpixels}$), $A_i$ contains the $n_{channels}$ elements that make up the spectrum of source $S_i$, and $M_i$ contains the morphology of the source, so that $S_i$ is the outer product of $A_i$ and $M_i$. The morphology is typically a 2-D array with the same number of pixels as the observation.
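The factorisation and the image model above can be sketched in a few lines of numpy. This is a toy illustration under simplifying assumptions (a Gaussian morphology for each source, the same Gaussian PSF in every band, white Gaussian noise); the helper names are hypothetical and this is not scarlet's internal code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bands, ny, nx = 3, 32, 32
y, x = np.indices((ny, nx))

def gaussian(cy, cx, sigma):
    """Unit-peak 2-D Gaussian used as a toy morphology M_i."""
    return np.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2 * sigma ** 2))

def fft_convolve(image, kernel):
    """Circular convolution of a 2-D image with a centred kernel of the same shape."""
    return np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(np.fft.ifftshift(kernel))).real

# Two sources: each is a spectrum A_i (one value per band) times a morphology M_i
spectra = [np.array([1.0, 2.0, 3.0]), np.array([3.0, 1.0, 0.5])]
morphs = [gaussian(12, 12, 2.0), gaussian(20, 18, 3.0)]
# S_i = A_i M_i is an outer product with shape (n_bands, ny, nx)
sources = [a[:, None, None] * m[None, :, :] for a, m in zip(spectra, morphs)]

# H: here a single normalised Gaussian PSF applied to every band
psf = gaussian(ny // 2, nx // 2, 1.5)
psf /= psf.sum()
scene = np.sum(sources, axis=0)
noise = rng.normal(scale=0.01, size=scene.shape)
# Y = H sum_i S_i + N, with the convolution done band by band
Y = np.array([fft_convolve(band, psf) for band in scene]) + noise
```

Each band of a factorised source is the same image rescaled by one spectral value, which is why the flattened $S_i$ has rank 1; a bulge plus disk source would need two such components.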

This decomposition is a modelling choice, motivated by the colour differences between sources in astronomical observations. However, it is possible to define sources that do not follow these assumptions. For instance, one could model each source in its entirety as an array $S_i$ that contains as many pixels as the observation, as a linear decomposition over an arbitrary basis, or as a non-linear analytic profile. 

Scarlet is flexible enough to support these models, but they require setting up the tools that map an arbitrary model to the data. 


# Set the weights for inverse variance weighting

Bands with higher noise variance are less informative to the fit, so they are down-weighted in the optimisation. The weight of each pixel is the inverse of its variance, $w = 1/\sigma^2$; here the variance is given for each pixel position in each band.
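As a sketch of how inverse-variance weights enter the fit, the function below computes a weighted squared-residual loss, $\frac{1}{2}\sum w\,(Y - \mathrm{model})^2$ with $w = 1/\sigma^2$, i.e. the negative Gaussian log-likelihood up to a constant. Treating pixels with zero, negative or non-finite variance as masked (weight 0) is an assumption of this example, and `inverse_variance_weights` and `weighted_loss` are hypothetical names.

```python
import numpy as np

def inverse_variance_weights(variance):
    """Return 1/variance, with weight 0 for pixels whose variance is not usable."""
    variance = np.asarray(variance, dtype=float)
    weights = np.zeros_like(variance)
    good = np.isfinite(variance) & (variance > 0)
    weights[good] = 1.0 / variance[good]
    return weights

def weighted_loss(model, data, weights):
    """Half the weighted sum of squared residuals (chi^2 / 2)."""
    residual = np.asarray(data, dtype=float) - np.asarray(model, dtype=float)
    return 0.5 * np.sum(weights * residual ** 2)

# A pixel with variance 4 counts four times less than a pixel with variance 1
w = inverse_variance_weights(np.array([[4.0, 0.0], [np.inf, 1.0]]))
loss = weighted_loss(np.zeros((2, 2)), np.ones((2, 2)), w)
```

Note that the variance plane returned by the butler already holds $\sigma^2$, so the weight is its reciprocal, not the reciprocal of its square.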

In [None]:
var = subset.variance.array
# The variance plane already holds sigma**2, so the inverse-variance weight is 1 / var
weights = 1 / var

# Define the model and observation frames

A key concept is the interplay between `Frame` and `Observation`. While the data have their native resolution with a given pixel size and PSF, a model for galaxies can be built at arbitrary resolution. The `Frame` defines the sampling of the model and the PSF to which the models are deconvolved.

A `Frame` in scarlet is the metadata that describes where the images of the model live. It includes the frame's shape, WCS (optional), and the PSF (technically optional but strongly recommended). 

The `Observation` defines where the data live, but also how to go from the model frame to the data. It contains the data themselves as an array, but also meta-information such as channel labels, PSF (technically optional but strongly recommended), WCS (optional) and weights (optional). The `Observation` needs to be matched to the `Frame` through the `match()` method. 

In scarlet it is possible to deblend scenes with observations from different instruments at different resolutions, and/or observations that have not been coadded, by building a list of `Observation`s. The `Frame` can be automatically built from the list of `Observation`s; however, that is outside the scope of this tutorial, and the interested reader is referred to https://fred3m.github.io/scarlet/tutorials/multiresolution.html.

So we will create an initial model `Frame` that uses a narrow Gaussian PSF and an `Observation` that consists of multiple bands of a DC2 coadded image.

See https://fred3m.github.io/scarlet/user_docs.html#Frame-and-Observation for more on `Frame`s and `Observation`s.

In [None]:
# Channel names, one per band of the image cube
channels = [f for f in filters]
# The model PSF is a narrow Gaussian, created as a `scarlet.GaussianPSF` object
model_psf = scarlet.GaussianPSF(sigma=0.9)

# Create the initial frame (metadata for the model).
frame = scarlet.Frame(images.shape, psfs=model_psf, channels=channels)

# Create our observation, with the measured PSF image of each band and the inverse-variance weights
observation = scarlet.Observation(images, psfs=scarlet.ImagePSF(psfs), channels=channels, weights=weights)

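The `Observation` has to be matched to the `Frame`. Among other things, this operation computes the difference kernel of each band, i.e. the kernel that, convolved with the model PSF (in the `Frame`), gives the observed PSF. Rendering the model into the observation then amounts to convolving it with this kernel, so the data never have to be deconvolved explicitly. After this step we can visualise the difference kernels.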
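The idea behind the difference kernel can be illustrated with Gaussian PSFs in numpy. The sketch below assumes Gaussian PSFs with illustrative widths and ignores the padding and per-band bookkeeping that scarlet handles; the helper names are hypothetical. It divides the Fourier transforms of the observed and model PSFs: convolving the model PSF with the result gives back the observed PSF, and for Gaussians the kernel is close to a Gaussian of variance $\sigma_{obs}^2 - \sigma_{model}^2$.

```python
import numpy as np

def gaussian_psf(sigma, size=41):
    """Normalised 2-D Gaussian centred on the middle pixel of a size x size grid."""
    y, x = np.indices((size, size)) - size // 2
    psf = np.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2))
    return psf / psf.sum()

def difference_kernel(psf_obs, psf_model):
    """Kernel D such that D convolved with psf_model equals psf_obs (circularly)."""
    # ifftshift moves the central pixel to (0, 0) so the transforms carry no phase
    obs_hat = np.fft.fft2(np.fft.ifftshift(psf_obs))
    model_hat = np.fft.fft2(np.fft.ifftshift(psf_model))
    # Division in Fourier space; well behaved here because the model PSF
    # is narrower than the observed PSF
    kernel = np.fft.ifft2(obs_hat / model_hat).real
    return np.fft.fftshift(kernel)

def convolve(image, kernel):
    """Circular FFT convolution of an image with a centred kernel of the same shape."""
    return np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(np.fft.ifftshift(kernel))).real

psf_model = gaussian_psf(0.9)   # narrow model PSF, same width as in the frame above
psf_obs = gaussian_psf(2.0)     # illustrative observed PSF width (an assumption)
kernel = difference_kernel(psf_obs, psf_model)
recovered = convolve(psf_model, kernel)

# Second moment of the kernel along x: close to 2.0**2 - 0.9**2 = 3.19
yy, xx = np.indices(kernel.shape) - kernel.shape[0] // 2
kernel_var_x = (xx ** 2 * kernel).sum() / kernel.sum()
```

If the model PSF were wider than the observed one, the division would amplify high frequencies, which is why the model PSF is chosen narrower than the PSF of every band.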

In [None]:
#Initially the diff kernel is initialised to None
print(f'the diff kernel does not exist: {observation._diff_kernels}')
#After matching, the diff kernel is instantiated
observation.match(frame)
imshow_rgb(observation._diff_kernels.image, norm = psf_norm, figsize = (15,15))

# Sources

The next key concept in scarlet is that of `Source`s. `Source` objects describe how the model is parametrized and which constraints are used in the optimisation. A source is tied to a position on the image grid and has a box size that determines the span of the object's light profile. 

## Detection

One `Source` object corresponds to the light profile of one astronomical object. Scarlet requires the position of each object in order to initialise its source, so objects have to be detected first. Scarlet does not have a detection algorithm of its own, since detection is outside the scope of the tool, but we provide a custom detection function here. 

Extension tools for scarlet, such as this detection function, can be found in the [scarlet extensions](https://github.com/fred3m/scarlet_extensions) package.

In [None]:
def makeCatalog(cube, lvl=3, wave=True):
    '''Creates a detection catalog from a multiband image cube.

    This function is used for detection before running scarlet.
    It is particularly useful for crowded stellar fields and for detecting high-frequency features.

    Parameters
    ----------
    cube: `numpy.ndarray`
        Multiband image cube with shape (bands, ny, nx)
    lvl: float
        Detection threshold, in units of the global background RMS
    wave: bool
        Set to True to detect on the finest wavelet scale of the coadded image

    Returns
    -------
    catalog: `numpy.ndarray`
        sep catalog of detected sources (with fields such as 'x' and 'y')
    '''
    # Imported here because sep is not available on desc-stack-weekly-latest
    import sep

    # Coadd all bands into a single 2D detection image
    detect_image = np.sum(cube, axis=0)

    if wave:
        # Keep only the finest wavelet scale: consider this a high-pass filter
        wave_detect = Starlet(detect_image).coefficients
        detect = wave_detect[0][0]
    else:
        # Direct detection on the coadd
        detect = detect_image

    # Estimate the background and extract sources above `lvl` times its RMS
    bkg = sep.Background(detect)
    catalog = sep.extract(detect, lvl, err=bkg.globalrms)

    return catalog

We would normally run the detection algorithm on the DC2 images and display the result of the detection. However, this custom function requires `sep`, which is not available on `desc-stack-weekly-latest`. Instead, I ran the detection on my machine and load the resulting catalog here.
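Since `sep` may not be available, here is a minimal numpy-only stand-in for the detection step. It is a hypothetical alternative, not what was used to build `cat.pkl`: it coadds the bands, estimates the noise with the median absolute deviation (MAD), and keeps pixels that are above `nsigma` times the noise and larger than all 8 of their neighbours. It returns a structured array with `x` and `y` fields, like the catalog below, but it does no deblending of its own and will report several peaks in structured galaxies.

```python
import numpy as np

def find_peaks(cube, nsigma=5.0):
    """Hypothetical simple detector: local maxima of the band coadd above nsigma * noise.

    Returns a structured array with 'x' and 'y' fields (pixel coordinates).
    """
    # Coadd all bands and subtract a robust estimate of the background level
    detect = np.sum(cube, axis=0)
    detect = detect - np.median(detect)
    # Robust noise estimate: 1.4826 * MAD approximates sigma for Gaussian noise
    noise = 1.4826 * np.median(np.abs(detect))
    # Pad with -inf so that border pixels can be compared with 8 neighbours
    padded = np.pad(detect, 1, mode="constant", constant_values=-np.inf)
    ny, nx = detect.shape
    is_peak = detect > nsigma * noise
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            neighbour = padded[1 + dy:1 + dy + ny, 1 + dx:1 + dx + nx]
            is_peak &= detect > neighbour
    ys, xs = np.nonzero(is_peak)
    catalog = np.zeros(len(xs), dtype=[("x", float), ("y", float)])
    catalog["x"] = xs
    catalog["y"] = ys
    return catalog
```

With `images` from above, `imshow_rgb(images, cat=find_peaks(images))` would overplot the detected peaks.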

In [None]:
# cat = makeCatalog(images)  # requires sep
with open('cat.pkl', 'rb') as f:
    cat = pickle.load(f)
imshow_rgb(images, figsize=(20, 20), cat=cat)


## Initializing Sources

Astrophysical objects are generally modeled in scarlet as a collection of factorized components, where each component has a single SED that is constant over its morphology (band-independent intensity). So a single source might have multiple components, like a bulge and a disk, or a single component.

The different classes that inherit from `FactorizedComponent` differ in how they are initialized and parametrized. This section illustrates the differences between different source initialization classes.

### <span style="color:red"> *WARNING* </span>
Scarlet accepts source positions using the numpy/C++ convention of (y,x), which is different from the astropy and LSST stack convention of (x,y).
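A short numpy illustration of the convention: a 2-D array is indexed as `image[row, column]`, i.e. `image[y, x]`, so the `x`/`y` fields of a catalog have to be swapped before they are passed to scarlet. The catalog here is a hypothetical stand-in that uses the same `x`/`y` field names as the sep catalog.

```python
import numpy as np

# A tiny image with a single bright pixel at x=3 (column), y=1 (row)
image = np.zeros((4, 6))
image[1, 3] = 1.0

# Hypothetical catalog in the (x, y) convention used by astropy and the LSST stack
catalog = np.array([(3.0, 1.0)], dtype=[("x", float), ("y", float)])

# Build (y, x) tuples, the order scarlet expects for source positions
positions = [(row["y"], row["x"]) for row in catalog]
for py, px in positions:
    print(image[int(py), int(px)])  # prints 1.0
```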

Below we demonstrate the usage of `ExtendedSource`, which initializes each object as a single component with maximum flux at the peak that falls off monotonically and has 180 degree symmetry.

In [None]:
sources = [scarlet.ExtendedSource(frame, (c['y'], c['x']), observation) for c in cat]

scarlet.display.show_scene(sources, 
                           norm = norm, 
                           observation=observation, 
                           show_observed=True)
plt.show()



In [None]:
# Display the initial guess for each source
scarlet.display.show_sources(sources,
                             norm=norm,
                             observation=observation,
                             show_rendered=True,
                             show_observed=True,
                             add_boxes = True)
plt.show()

## Exercise

* Experiment with various `Source` classes: `PointSource`, `MultiExtendedSource` (an extended source made of several components, set by the argument `K`), `StarletSource`.
* Pick a different source depending on the visual appearance of the objects.
* Later: run scarlet with different initializations.
* For the boldest and bravest: come up with your own sources.

In [None]:
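# Indices (in `cat`) of the detections you want to model differently,
# e.g. star-like objects: fill this list in yourself
special = []

sources = []
for i, c in enumerate(cat):
    position = (c['y'], c['x'])  # scarlet expects (y, x)
    if i in special:
        # Replace with your favourite kind of source, e.g. a point source
        source = scarlet.PointSource(frame, position, observation)
    else:
        source = scarlet.ExtendedSource(frame, position, observation)
    sources.append(source)
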
scarlet.display.show_scene(sources, 
                           norm = norm, 
                           observation=observation, 
                           show_rendered=True, 
                           show_observed=True, 
                           show_residual=True)
plt.show()

# Deblending a scene

The `Blend` class contains the list of sources, the observation(s) and any other configuration parameters necessary to fit the data.

In [None]:
blend = scarlet.Blend(sources, observation)

Next we can fit a model, given a maximum number of iterations and the relative error required for convergence.
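To illustrate what a maximum number of iterations and a relative-error threshold mean in practice, here is a toy projected gradient descent that fits non-negative amplitudes of two fixed templates and stops once the relative change of the parameters falls below `e_rel`. This is a simplified stand-in written for this tutorial (`fit_amplitudes` is a hypothetical name); scarlet's own optimiser and its exact convergence test differ in detail.

```python
import numpy as np

def fit_amplitudes(data, templates, max_iter=200, e_rel=1e-3):
    """Toy projected gradient descent for data ~ sum_k amp_k * template_k, amp_k >= 0.

    Stops when ||amp_new - amp_old|| <= e_rel * ||amp_new|| or after max_iter steps.
    Returns the amplitudes and the loss history.
    """
    T = np.asarray(templates, dtype=float).reshape(len(templates), -1)
    y = np.asarray(data, dtype=float).ravel()
    # Step size 1/L, with L the Lipschitz constant of the gradient of 0.5*||y - T.T a||^2
    step = 1.0 / np.linalg.norm(T @ T.T, 2)
    amp = np.zeros(len(T))
    losses = []
    for _ in range(max_iter):
        residual = y - T.T @ amp
        losses.append(0.5 * residual @ residual)
        # Gradient step, then projection onto the constraint amp >= 0
        new_amp = np.maximum(amp + step * (T @ residual), 0.0)
        converged = np.linalg.norm(new_amp - amp) <= e_rel * np.linalg.norm(new_amp)
        amp = new_amp
        if converged:
            break
    return amp, losses
```

Because the iterates do not depend on `e_rel`, a smaller `e_rel` can only make the fit run for the same number of iterations or longer, in exchange for a more precise result.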

In [None]:
# Fit the data until the relative error is <= 1e-3,
# for a maximum of 200 iterations
%time blend.fit(200, e_rel = 1e-3)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(-np.array(blend.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

Scarlet provides two functions to display the result. `scarlet.display.show_scene` shows the full model along with the observation and the residuals, defined as `observation.images - model`, while `scarlet.display.show_sources` shows each source individually. 

In [None]:
scarlet.display.show_scene(sources, norm=norm,linear=True, 
                           observation=observation, 
                           show_observed=True, 
                           show_rendered=True, 
                           show_residual=True
                          )
plt.show()

Scarlet can perform basic measurements on sources. Since the source models are isolated and noiseless, one can, for instance, compute the flux of each source in each band.
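For a factorised source, the flux in each band follows directly from its two factors: summing $S_i = A_iM_i$ over pixels gives $A_i \sum M_i$. The numpy sketch below checks this identity on a toy source; `factorized_flux` is a hypothetical helper that illustrates the idea and is not the implementation of `scarlet.measure.flux`, which operates on scarlet source objects.

```python
import numpy as np

def factorized_flux(spectrum, morphology):
    """Flux per band of S = spectrum (outer) morphology: spectrum * sum(morphology)."""
    return np.asarray(spectrum, dtype=float) * np.sum(morphology)

# Toy source: Gaussian morphology normalised to a peak of 1, three-band spectrum
yy, xx = np.indices((21, 21))
morph = np.exp(-((yy - 10) ** 2 + (xx - 10) ** 2) / (2 * 2.0 ** 2))
spectrum = np.array([0.5, 1.0, 2.0])
source = spectrum[:, None, None] * morph[None]

flux_from_factors = factorized_flux(spectrum, morph)
flux_from_pixels = source.sum(axis=(1, 2))
```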

In [None]:
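fluxes = []
for s in sources:
    fluxes.append(scarlet.measure.flux(s))

# One line per source: flux as a function of band
plt.plot(np.array(fluxes).T)
plt.xticks(range(len(filters)), list(filters))
plt.xlabel('Band')
plt.ylabel('Flux')
plt.show()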

In [None]:
# Display the fitted model for each source
scarlet.display.show_sources(sources,
                             norm=norm,
                             observation=observation,
                             show_rendered=True,
                             show_observed=True,
                             add_boxes = True)
plt.show()

## To go further

Sources as `FactorizedComponent`s are made of a `TabulatedSpectrum` that describes the contribution of a source to each band, and a `Morphology` that describes the 2D profile of the source. Depending on the source, one may want to impose different constraints on the shape of its light profile. This is done by building a chain of constraints associated with the parameters of the morphology.
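To get a feeling for what a chain of morphology constraints does, here is a numpy sketch that applies three simple projections in sequence to a noisy image: symmetrisation (averaging with the 180° rotated image), positivity (clipping negative pixels to zero) and normalisation by the maximum. These are simplified stand-ins written for this tutorial, assuming the source centre is the central pixel; scarlet's `SymmetryConstraint`, `PositivityConstraint` and `NormalizationConstraint` act on the model during optimisation and differ in detail.

```python
import numpy as np

def symmetrize(image):
    """Average with the 180-degree rotated image (centre at the middle pixel)."""
    return 0.5 * (image + image[::-1, ::-1])

def positivity(image):
    """Clip negative pixels to zero."""
    return np.maximum(image, 0.0)

def normalize_max(image):
    """Rescale so that the maximum pixel is 1 (breaks the spectrum/morphology degeneracy)."""
    return image / image.max()

def apply_chain(image, constraints):
    """Apply each constraint in turn, like a chain of projections."""
    for constraint in constraints:
        image = constraint(image)
    return image

rng = np.random.default_rng(42)
yy, xx = np.indices((15, 15))
truth = np.exp(-((yy - 7) ** 2 + (xx - 7) ** 2) / 8.0)
noisy = truth + rng.normal(scale=0.05, size=truth.shape)
morph = apply_chain(noisy, [symmetrize, positivity, normalize_max])
```

The order matters: clipping and rescaling preserve symmetry, so applying them after symmetrisation leaves an image that satisfies all three constraints at once.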

Let us declare a custom-ish `Source` class, initialize sources with it and see how it runs.

In [None]:
from scarlet import TabulatedSpectrum, FactorizedComponent, init_extended_source, ImageMorphology
from scarlet.constraint import (
    ConstraintChain,
    MonotonicityConstraint,
    SymmetryConstraint,
    PositivityConstraint,
    CenterOnConstraint,
    NormalizationConstraint,
)
from scarlet.parameter import Parameter

class NewSourceMorphology(ImageMorphology):
    def __init__(
        self,
        frame,
        center,
        image,
        bbox=None,
        monotonic="angle",
        symmetric=True,
        min_grad=0,
        shifting=False,
    ):
        """Non-parametric image morphology designed for galaxies as extended sources.

        Parameters
        ----------
        frame: `~scarlet.Frame`
            The frame of the full model
        center: tuple
            Center of the source
        image: `numpy.ndarray`
            Image of the source.
        bbox: `~scarlet.Box`
            2D bounding box for location of the image in `frame`
        monotonic: ['flat', 'angle', 'nearest'] or None
            Which version of monotonic decrease in flux from the center to enforce
        symmetric: `bool`
            Whether or not to enforce symmetry.
        min_grad: float in [0,1)
            Minimal radial decline for monotonicity (in units of reference pixel value)
        shifting: `bool`
            Whether or not a subpixel shift is added as optimization parameter
        """

        constraints = []
        # backwards compatibility: monotonic was boolean
        if monotonic is True:
            monotonic = "angle"
        elif monotonic is False:
            monotonic = None
        if monotonic is not None:
            # most astronomical sources are monotonically decreasing
            # from their center
            constraints.append(
                MonotonicityConstraint(neighbor_weight=monotonic, min_gradient=min_grad)
            )

        if symmetric:
            # have 2-fold rotation symmetry around their center ...
            constraints.append(SymmetryConstraint())

        constraints += [
            # ... and are positive emitters
            PositivityConstraint(),
            # prevent a weak source from disappearing entirely
            CenterOnConstraint(),
            # break degeneracies between sed and morphology
            NormalizationConstraint("max"),
        ]
        morph_constraint = ConstraintChain(*constraints)
        image = Parameter(image, name="image", step=1e-2, constraint=morph_constraint)

        self.pixel_center = np.round(center).astype("int")
        if shifting:
            shift = Parameter(center - self.pixel_center, name="shift", step=1e-1)
        else:
            shift = None
        self.shift = shift

        super().__init__(frame, image, bbox=bbox, shift=shift)

    @property
    def center(self):
        if self.shift is not None:
            return self.pixel_center + self.shift
        else:
            return self.pixel_center

        
class NewExtendedSource(FactorizedComponent):
    def __init__(
        self,
        model_frame,
        sky_coord,
        observations,
        coadd=None,
        coadd_rms=None,
        thresh=1.0,
        compact=False,
        shifting=False,
    ):
        """Extended source model

        The model is initialized from `observations` with a symmetric and
        monotonic profile and a spectrum from its peak pixel.

        During optimization it enforces positivity for spectrum and morphology,
        as well as monotonicity of the morphology.

        Parameters
        ----------
        model_frame: `~scarlet.Frame`
            The frame of the full model
        sky_coord: tuple
            Center of the source
        observations: instance or list of `~scarlet.observation.Observation`
            Observation(s) to initialize this source.
        coadd: `numpy.ndarray`
            The coaddition of all images across observations.
        coadd_rms: float
            Noise level of the coadd
        thresh: `float`
            Multiple of the background RMS used as a
            flux cutoff for morphology initialization.
        compact: `bool`
            Initialize with the shape of a point source
        shifting: `bool`
            Whether or not a subpixel shift is added as optimization parameter
        """
        # initialize from observation
        spectrum, morph, bbox = init_extended_source(
                sky_coord,
                model_frame,
                observations,
                coadd,
                coadd_rms=coadd_rms,
                thresh=thresh,
                compact=compact,
                symmetric=True,
                monotonic="flat",
                min_grad=0,
            )

        spectrum = TabulatedSpectrum(model_frame, spectrum, bbox=bbox[0])

        center = model_frame.get_pixel(sky_coord)
        morphology = NewSourceMorphology(
            model_frame,
            center,
            morph,
            bbox=bbox[1:],
            shifting=shifting,
        )
        self.center = morphology.center
        super().__init__(model_frame, spectrum, morphology)

# Exercise

Your turn now: deblend this scene, using the detection catalog I provided.

In [None]:
# Fetch the images
cutout_size = 300
cutout_extent = lsst.geom.ExtentI(cutout_size, cutout_size)
skymap = butler.get('deepCoadd_skyMap')
radec = lsst.geom.SpherePoint(56.811321, -31.123123, lsst.geom.degrees)
center = skymap.findTract(radec).getWcs().skyToPixel(radec)
# Lower-left corner of the cutout, as integer pixel coordinates
corner = lsst.geom.Point2I(int(center.x - cutout_size*0.5), int(center.y - cutout_size*0.5))
bbox = lsst.geom.BoxI(corner, cutout_extent)

cutouts = [butler.get("deepCoadd_sub", bbox=bbox, tract=4639, patch='1,0', filter=band) for band in "ugrizy"]
coadds = MultibandExposure.fromExposures("ugrizy", cutouts)

# PSF, evaluated at the corner of the cutout
psfs = coadds.computePsfImage(corner).array
# Cube image
cube = coadds.image.array
# Inverse-variance weights (the variance plane already holds sigma**2)
var2 = coadds.variance.array
weights2 = 1 / var2

Here is the catalog of detections along with the image of the cutout. 
Note that this catalog was generated by running sep directly on the images, without wavelet filtering, 
because wavelet filtering causes the substructures of the galaxies to be detected as separate sources.
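Wavelet filtering here refers to the starlet (isotropic undecimated wavelet) transform used by `makeCatalog` above. The numpy sketch below implements the standard à trous construction with a B3-spline kernel, assuming periodic boundaries for simplicity; scarlet's `Starlet` class treats boundaries differently and has its own interface, and `starlet_transform` is a hypothetical name. Each wavelet plane is the difference between two successive smoothings, so the planes plus the final coarse image sum back to the original, and dropping the coarse image leaves a high-pass filtered image in which small-scale features such as substructures stand out.

```python
import numpy as np

B3 = np.array([1, 4, 6, 4, 1]) / 16.0

def _smooth(image, step):
    """Separable B3-spline smoothing with holes of size `step` (periodic edges)."""
    out = image
    for axis in (0, 1):
        acc = np.zeros_like(out)
        for weight, offset in zip(B3, (-2, -1, 0, 1, 2)):
            acc += weight * np.roll(out, offset * step, axis=axis)
        out = acc
    return out

def starlet_transform(image, n_scales=3):
    """Return [w_1, ..., w_J, c_J]: wavelet planes from fine to coarse, then the coarse image."""
    coeffs = []
    current = np.asarray(image, dtype=float)
    for j in range(n_scales):
        # The holes double at each scale: 1, 2, 4, ... pixels
        smooth = _smooth(current, 2 ** j)
        coeffs.append(current - smooth)
        current = smooth
    coeffs.append(current)
    return np.array(coeffs)

rng = np.random.default_rng(0)
img = rng.normal(size=(32, 32))
coeffs = starlet_transform(img, n_scales=3)
# High-pass filtered image: everything except the coarsest smooth component
high_pass = coeffs[:-1].sum(axis=0)
```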

In [None]:
# Open the catalog
with open('cat2.pkl', 'rb') as f:
    cat = pickle.load(f)
# Display the image
imshow_rgb(cube, figsize=(20, 20), cat=cat)

## Task 1

Set up the frame and observation and match them:

In [None]:
# Channel names, one per band of `cube`
channels = "..."
# Create a narrow Gaussian model PSF using a `scarlet.GaussianPSF` object
model_psf = "..."

# Create the initial frame (metadata for the model).
frame = scarlet.Frame("...")

# Create our observation and match it to the frame
observation = scarlet.Observation("...").match(frame)

## Task 2

Initialize the sources and display the initial model. 

Once your first run of scarlet on this cutout is done, try changing the types of sources you used.

In [None]:
"""You're on your own here, you'll have to declare the sources
...
"""

scarlet.display.show_scene(sources, 
                           norm = norm, 
                           observation=observation, 
                           show_observed=True)
plt.show()

##  Task 3

Run scarlet and display the results. Suggestion: play with the `e_rel` parameter to see how it affects the fit. `e_rel` is the convergence criterion: the smaller `e_rel`, the longer the algorithm runs before it converges (up to the maximum number of iterations).

In [None]:
blend = scarlet.Blend(sources, observation)
%time blend.fit(200, e_rel = 1e-3)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(-np.array(blend.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
"""Display 
...
"""