# Deblending with *Scarlet*
<br>Owner(s): **Fred Moolekamp** ([@fred3m](https://github.com/LSSTScienceCollaborations/StackClub/issues/new?body=@fred3m))
<br>Last Verified to Run: **2020-07-10**
<br>Verified Stack Release: **v20.0.0**

The purpose of this tutorial is to familiarize you with the basics of using *scarlet* to model blended scenes, and how tweaking various objects and parameters affects the resulting model. A tutorial that is more specific to using scarlet in the context of the LSST DM Science Pipelines is also available.

### Learning Objectives:

After working through this tutorial you should be able to: 
1. Configure and run _scarlet_ on a test list of objects;
2. Understand its various model assumptions and applied constraints;
3. Use specific configurations to fit objects of different nature (stars, galaxies, low-surface-brightness galaxies);
4. Bonus: get a sense of how to build models in scarlet that encode your own assumptions.

Before attempting this tutorial it will be useful to read the [introduction](https://fred3m.github.io/scarlet/user_docs.html) to the *scarlet* User Guide, and many of the exercises below may require referencing the *scarlet* [docs](https://fred3m.github.io/scarlet/).

### Logistics
This notebook is intended to be runnable on `cori.nersc.gov` from a local git clone of https://github.com/LSSTDESC/StackClub.

## Set-up

In [None]:
# What version of the Stack are we using?
! echo $HOSTNAME
! eups list -s | grep lsst_distrib

In [None]:
# Import the necessary libraries
import os

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# don't interpolate the pixels
matplotlib.rc('image', interpolation='none')

import numpy as np
from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping

import scarlet
import scarlet.display

We also import the Butler and the LSST Science Pipelines packages we need.

In [None]:
from lsst.daf.persistence import Butler
from lsst.geom import Box2I, Box2D, Point2I, Point2D, Extent2I, Extent2D
from lsst.afw.image import Exposure, Image, PARENT, MultibandExposure, MultibandImage
from lsst.afw.detection import MultibandFootprint

# Load and Display the data

More information is provided in the `lsst_stack_deblender.ipynb` tutorial.

The **butler** is used to retrieve the DESC DC2 DR6 coadds by specifying the tract, patch and filters.

In [None]:
import desc_dc2_dm_data
butler = desc_dc2_dm_data.get_butler("2.2i_dr6_wfd")
dataId = {"tract": 3830, "patch": "4,4"}
filters = "ugrizy"
coadds = [butler.get("deepCoadd_calexp", dataId, filter=f) for f in filters]
coadds = MultibandExposure.fromExposures(filters, coadds)

We then display the patch using scarlet's built-in display functions. The `norm` defines an asinh color stretch, also known as _Luptonisation_, that keeps the centers of bright objects from saturating to white.

The `img_to_rgb` function maps multi-band arrays into 3-channel RGB images. By default the mapping is automatic and assumes the bands are ordered from bluer to redder, but it can be customised with the `channel_map` argument.
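
To make the band-to-color mapping concrete, here is a minimal NumPy sketch of the idea behind an asinh RGB composite. It is not scarlet's implementation: the helper `bands_to_rgb` and the 3×6 `channel_map` below (u, g to blue; r, i to green; z, y to red) are hypothetical choices for illustration, and scarlet's own default mapping may weight the bands differently.

```python
import numpy as np

# Hypothetical 3x6 channel map for bands ordered u, g, r, i, z, y (blue to red).
# Each row builds one output channel (R, G, B) as a weighted sum of bands.
channel_map = np.array([
    [0.0, 0.0, 0.0, 0.0, 0.5, 0.5],  # R <- z, y
    [0.0, 0.0, 0.5, 0.5, 0.0, 0.0],  # G <- r, i
    [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],  # B <- u, g
])


def bands_to_rgb(cube, channel_map, Q=10.0, stretch=1.0, minimum=0.0):
    """Sketch of an asinh RGB composite for a (bands, ny, nx) cube.

    Returns a (ny, nx, 3) float array in [0, 1], ready for plt.imshow.
    """
    # (3, bands) x (bands, ny, nx) -> (3, ny, nx)
    rgb = np.tensordot(channel_map, cube, axes=1) - minimum
    # Stretch the mean intensity with asinh, then scale all three channels
    # by the same factor so that the color of each pixel is preserved
    intensity = rgb.mean(axis=0)
    with np.errstate(divide="ignore", invalid="ignore"):
        factor = np.where(intensity > 0,
                          np.arcsinh(Q * intensity / stretch) / (Q * intensity),
                          0.0)
    rgb = np.clip(rgb * factor, 0.0, None)
    # Normalise so that the brightest channel value maps to 1
    peak = rgb.max()
    if peak > 0:
        rgb /= peak
    # Move the color axis last, as matplotlib expects
    return np.moveaxis(rgb, 0, -1)
```

For example, a cube with flux only in the y band produces a pure red pixel, since y feeds only the R row of `channel_map`.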

In [None]:
norm = scarlet.display.AsinhMapping(minimum=0, stretch=1, Q=10)
rgb_patch = scarlet.display.img_to_rgb(coadds.image.array, norm=norm)

In [None]:
plt.figure(figsize = (30,30))
plt.imshow(rgb_patch, origin = "lower")
plt.xticks(fontsize = 30)
plt.yticks(fontsize = 30)
plt.show()

Let's wrap the previous display steps in a reusable function.

In [None]:
def imshow_rgb(image, norm=None, figsize=None):
    """Display a multiband image as an RGB composite.

    Parameters
    ----------
    image: `numpy.ndarray`
        Multiband image to display, with shape (bands, Ny, Nx)
    norm: `scarlet.display.AsinhMapping`, optional
        Color normalization; defaults to an asinh stretch with Q=10
    figsize: `tuple`, optional
        Figure size in inches; defaults to (30, 30)
    """
    if norm is None:
        norm = scarlet.display.AsinhMapping(minimum=0, stretch=1, Q=10)
    rgb_patch = scarlet.display.img_to_rgb(image, norm=norm)
    if figsize is None:
        figsize = (30, 30)

    plt.figure(figsize=figsize)
    plt.imshow(rgb_patch, origin="lower")
    plt.xticks(fontsize=30)
    plt.yticks(fontsize=30)
    plt.show()

## Display a subset

Here we extract a small 180 × 180 pixel region of the patch to run scarlet on.

In [None]:
n = 180
sampleBBox = Box2I(Point2I(16880, 19320), Extent2I(n, n))

subset = coadds[:, sampleBBox]
# Due to a bug in the code the PSF isn't copied properly.
# The code below copies the PSF into the `MultibandExposure`,
# but will be unnecessary in the future
for f in subset.filters:
    subset[f].setPsf(coadds[f].getPsf())

In [None]:
imshow_rgb(subset.image.array)

# Display the PSFs

In [None]:
psfs = subset.computePsfImage(Point2I(16880, 19320)).array
psf_norm = scarlet.display.AsinhMapping(minimum=0, stretch=0.001, Q=10)
imshow_rgb(psfs, norm = psf_norm)

# Define the model and observation frames

A `Frame` in scarlet is the metadata that describes where the model lives. It includes the frame's shape, WCS (optional), and PSF (technically optional but strongly recommended).

The `Observation` defines where the data live, and also how to go from the model frame to the data. It contains the data themselves as an array, along with metadata such as channel tags, the PSF (technically optional but strongly recommended), the WCS (optional) and weights (optional). The `Observation` needs to be matched to the `Frame` through its `match()` method.

In scarlet it is possible to deblend scenes with observations from different instruments at different resolutions, and/or observations that have not been coadded, by building a list of `Observation`s. The `Frame` can be built automatically from that list; this is outside the scope of this tutorial, and the interested reader is referred to https://fred3m.github.io/scarlet/tutorials/multiresolution.html.

So we will create an initial model `Frame` that uses a narrow Gaussian PSF, and an `Observation` that consists of multiple bands of a DC2 coadd.

See https://fred3m.github.io/scarlet/user_docs.html#Frame-and-Observation for more on `Frame`s and `Observation`s.
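
Conceptually, matching the observation to the model frame means finding a difference kernel that, convolved with the narrow model PSF, reproduces the observed PSF; rendering the model in the observation frame then amounts to that convolution. For Gaussians the widths add in quadrature, $\sigma_\mathrm{obs}^2 = \sigma_\mathrm{model}^2 + \sigma_\mathrm{diff}^2$. The NumPy sketch below checks this numerically. It illustrates the principle only and is not scarlet's code; the observed width $\sigma_\mathrm{obs} = 2$ pixels is an arbitrary assumption.

```python
import numpy as np


def gaussian_image(sigma, size=31):
    """Normalized, centered 2D Gaussian sampled on a size x size grid."""
    y, x = np.mgrid[:size, :size] - size // 2
    img = np.exp(-(x**2 + y**2) / (2.0 * sigma**2))
    return img / img.sum()


def convolve_fft(image, kernel):
    """Circular convolution of two same-shape arrays; the kernel is centered."""
    # ifftshift moves the kernel center to pixel (0, 0) so the result is not shifted
    return np.real(np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(np.fft.ifftshift(kernel))))


def gaussian_width(img):
    """Width from the second moment along x (the images are circular)."""
    size = img.shape[1]
    x = np.arange(size) - size // 2
    return np.sqrt(np.sum(img.sum(axis=0) * x**2) / img.sum())


sigma_model, sigma_obs = 0.9, 2.0
# Width of the difference kernel that turns the model PSF into the observed PSF
sigma_diff = np.sqrt(sigma_obs**2 - sigma_model**2)

model_psf = gaussian_image(sigma_model)
rendered_psf = convolve_fft(model_psf, gaussian_image(sigma_diff))
print(gaussian_width(rendered_psf))  # close to sigma_obs
```

This is also why the model PSF should be narrower than every observed PSF: otherwise no such difference kernel exists.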

In [None]:
??scarlet.psf

In [None]:
from functools import partial

# The pixel data and inverse-variance weights of the subset
images = subset.image.array
weights = 1 / subset.variance.array
channels = list(filters)

# Create a narrow gaussian to use as the model PSF.
# `shape=(None, 8, 8)` leaves the channel dimension unspecified,
# so the same PSF is broadcast to every band of the frame
model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=0.9), shape=(None, 8, 8))
# Make sure that the observation PSF is normalized (otherwise the scaling in PSF matching might be off)
psfs = psfs / psfs.sum(axis=(1, 2))[:, None, None]

# Create the initial frame (metadata for the model)
frame = scarlet.Frame(images.shape, psfs=model_psf, channels=channels)

# Create our observation and match it to the model frame
observation = scarlet.Observation(images, psfs=psfs, channels=channels, weights=weights).match(frame)

# Initializing Sources

Astrophysical objects are modeled in scarlet as a collection of components, where each component has a single SED that is constant over its morphology (a band-independent intensity profile). A source can therefore consist of a single component, or of multiple components such as a bulge and a disk.

The classes that inherit from `Source` differ mainly in how they are initialized, and behave similarly during the optimization. This section illustrates the differences between these initialization classes.
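
In other words, each component contributes the outer product of its SED (one amplitude per band) and its morphology (one image shared by all bands), and a source's model is the sum over its components. Below is a minimal NumPy sketch with two hypothetical components, a compact red "bulge" and an extended blue "disk"; the SED values and profile widths are made up for illustration.

```python
import numpy as np

ny = nx = 41
y, x = np.mgrid[:ny, :nx] - ny // 2
r2 = x**2 + y**2

# Hypothetical morphologies, each normalized to unit total flux
bulge_morph = np.exp(-r2 / (2 * 1.5**2))
bulge_morph /= bulge_morph.sum()
disk_morph = np.exp(-np.sqrt(r2) / 4.0)
disk_morph /= disk_morph.sum()

# Hypothetical SEDs: total flux of each component in bands u, g, r, i, z, y
bulge_sed = np.array([0.5, 1.0, 2.0, 3.0, 3.5, 4.0])  # redder
disk_sed = np.array([3.0, 3.5, 3.0, 2.5, 2.0, 1.5])   # bluer


def component_model(sed, morph):
    """Outer product: (bands,) x (ny, nx) -> (bands, ny, nx)."""
    return sed[:, None, None] * morph[None, :, :]


# The source model is the sum of its components
source_model = component_model(bulge_sed, bulge_morph) + component_model(disk_sed, disk_morph)
print(source_model.shape)  # (6, 41, 41)
```

Because a component's morphology is shared by all bands, a single component has the same color everywhere; a color gradient, like the redder core of this toy source, needs more than one component.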

### <span style="color:red"> *WARNING* </span>
Scarlet accepts source positions using the numpy/C++ convention of (y,x), which is different from the astropy and LSST stack convention of (x,y).
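
Since the subset array starts at the minimum corner of `sampleBBox`, converting an LSST parent-frame position (x, y) into the (y, x) array position scarlet expects takes two steps: subtract the bounding-box origin, then swap the axes. A minimal sketch; `parent_xy_to_local_yx` is a hypothetical helper, and the origin (16880, 19320) is the minimum corner of `sampleBBox` used above.

```python
import numpy as np


def parent_xy_to_local_yx(x, y, bbox_min):
    """Convert a parent-frame (x, y) position to array-local (y, x).

    bbox_min is the (x0, y0) minimum corner of the sub-image's bounding box.
    """
    x0, y0 = bbox_min
    return (y - y0, x - x0)


bbox_min = (16880, 19320)  # minimum corner of sampleBBox
yx = parent_xy_to_local_yx(16900, 19400, bbox_min)
print(yx)  # (80, 20)

# numpy indexes arrays as [row, column] = [y, x], which is why scarlet uses (y, x)
image = np.zeros((180, 180))
image[yx] = 1.0
```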

Below we demonstrate the usage of `ExtendedSource`, which initializes each object as a single component with maximum flux at the peak that falls off monotonically and has 180 degree symmetry.
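
The two constraints behind `ExtendedSource` can be illustrated on a plain NumPy array. The sketch below is not scarlet's implementation: it symmetrizes a morphology by averaging it with its 180° rotation, and checks a simple form of monotonicity in which no pixel is brighter than its neighbor one step closer to the center. Both helpers are hypothetical and assume the peak sits at the central pixel of an odd-sized array.

```python
import numpy as np


def symmetrize(morph):
    """Average with the 180-degree rotation about the central pixel."""
    return 0.5 * (morph + morph[::-1, ::-1])


def is_monotonic(morph):
    """True if no pixel is brighter than its neighbor one step toward the center."""
    ny, nx = morph.shape
    cy, cx = ny // 2, nx // 2
    for i in range(ny):
        for j in range(nx):
            if (i, j) == (cy, cx):
                continue
            # Neighbor one pixel closer to the center along each axis
            pi = i - int(np.sign(i - cy))
            pj = j - int(np.sign(j - cx))
            if morph[i, j] > morph[pi, pj]:
                return False
    return True


y, x = np.mgrid[:21, :21] - 10
gaussian = np.exp(-(x**2 + y**2) / 8.0)
# A lopsided profile: brighter on the right than on the left
lopsided = gaussian * (1.0 + 0.5 * (x > 0))
sym = symmetrize(lopsided)
print(np.allclose(sym, sym[::-1, ::-1]))  # True
```

Symmetry alone does not imply monotonicity: `is_monotonic(sym)` is still `False` here, because both sides of the symmetrized profile remain brighter than its center. That is why the two constraints are distinct.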

In [None]:
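from scipy.ndimage import maximum_filter

# `peaks` must hold (x, y) peak positions in the subset's local pixel coordinates.
# As a simple stand-in for a proper detection step, take the local maxima of the
# band-summed image that are more than 5 sigma above the noise.
detection = images.sum(axis=0)
noise = np.sqrt(subset.variance.array.sum(axis=0))
is_peak = (detection == maximum_filter(detection, size=5)) & (detection > 5 * noise)
peaks = [(x, y) for y, x in zip(*np.nonzero(is_peak))]

sources = [scarlet.ExtendedSource(frame, (peak[1], peak[0]), observation) for peak in peaks]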

# Display the initial guess for each source
scarlet.display.show_sources(sources,
                             norm=norm,
                             observation=observation,
                             show_rendered=True,
                             show_observed=True)
plt.show()

## Exercise:

* Experiment with the above code by replacing `ExtendedSource` with `MultiComponentSource`, which models a source as two components (a bulge and a disk) that are each symmetric and monotonically decreasing from the peak.

# Deblending a scene

The `Blend` class contains the list of sources, the observation(s), and any other configuration parameters necessary to fit the data.

In [None]:
blend = scarlet.Blend(sources, observation)

Next we can fit a model, given a maximum number of iterations and the relative error required for convergence.
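
The `e_rel` tolerance is a relative stopping criterion. The toy sketch below fits a single amplitude by gradient descent and stops when the loss changes by less than `e_rel` times its current value between iterations. It is an illustration under stated assumptions: the loss, the fixed step size and the `fit_amplitude` helper are hypothetical, and scarlet's own convergence test may differ between versions.

```python
import numpy as np


def fit_amplitude(data, template, max_iter=200, e_rel=1e-3):
    """Fit data ~ amplitude * template by gradient descent on 0.5 * sum(residual**2)."""
    step = 0.5 / np.sum(template**2)  # conservative fixed step size
    amplitude, loss = 0.0, []
    for it in range(max_iter):
        residual = data - amplitude * template
        loss.append(0.5 * np.sum(residual**2))
        # Relative-change stopping rule, analogous to e_rel above
        if it > 0 and abs(loss[-2] - loss[-1]) < e_rel * abs(loss[-1]):
            return amplitude, loss, True
        # Gradient step: d(loss)/d(amplitude) = -sum(template * residual)
        amplitude += step * np.sum(template * residual)
    return amplitude, loss, False


rng = np.random.default_rng(42)
y, x = np.mgrid[:21, :21] - 10
template = np.exp(-(x**2 + y**2) / 8.0)
data = 3.0 * template + rng.normal(scale=0.1, size=template.shape)

amp, loss, converged = fit_amplitude(data, template)
print(converged, len(loss), amp)
```

Lowering `e_rel` (for example `fit_amplitude(data, template, e_rel=1e-6)`) runs more iterations and settles closer to the minimum; raising it stops earlier.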

In [None]:
# Fit the data until the relative error is <= 1e-3,
# for a maximum of 200 iterations
%time blend.fit(200, e_rel = 1e-3)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(-np.array(blend.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

There are two options for displaying the scene. The first is the `scarlet.display.show_scene` function, which shows the model along with the observation and the residuals, defined as `observation.images - model` (with the model rendered in the observation frame).
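
For Gaussian noise with inverse-variance weights $w$, the log-likelihood is, up to an additive constant, $\log L = -\frac{1}{2}\sum w\,(\mathrm{data} - \mathrm{model})^2$, so the residual image and the fit quality are directly related. A minimal NumPy sketch on synthetic arrays; the shapes, noise level and `weighted_residual_stats` helper are all hypothetical.

```python
import numpy as np


def weighted_residual_stats(images, model, weights):
    """Return the residual (images - model), chi^2 and chi^2 per pixel."""
    residual = images - model
    chi2 = np.sum(weights * residual**2)
    return residual, chi2, chi2 / residual.size


rng = np.random.default_rng(0)
sigma = 0.1
model = np.zeros((6, 50, 50))
model[:, 25, 25] = 5.0  # a toy point source in all six bands
images = model + rng.normal(scale=sigma, size=model.shape)
weights = np.full(model.shape, 1 / sigma**2)  # inverse variance

residual, chi2, chi2_per_pixel = weighted_residual_stats(images, model, weights)
log_likelihood = -0.5 * chi2  # up to an additive constant
print(chi2_per_pixel)  # close to 1 when the model describes the data
```

A $\chi^2$ per pixel well above 1, or coherent structure in the residual image, usually points to a missing or poorly modeled source.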

In [None]:
scarlet.display.show_scene(sources, norm=norm,linear=True, 
                           observation=observation, show_observed=True, 
                           label_sources=True, 
                           show_rendered=True, 
                           show_residual=True
                          )
plt.show()

The second option is to build the same figure by hand.

In [None]:
# Load the model and calculate the residual
model = blend.get_model()
model_ = observation.render(model)  # render the model in the observation frame (PSF matching)
residual = images - model_

# Create RGB images
img_rgb = scarlet.display.img_to_rgb(images, norm=norm)
model_rgb = scarlet.display.img_to_rgb(model_, norm=norm)
residual_rgb = scarlet.display.img_to_rgb(residual)

# Show the data, model, and residual
fig = plt.figure(figsize=(15,5))
ax = [fig.add_subplot(1,3,n+1) for n in range(3)]
ax[0].imshow(img_rgb, origin="lower")
ax[0].set_title("Data")
ax[1].imshow(model_rgb, origin="lower")
ax[1].set_title("Model")
ax[2].imshow(residual_rgb, origin="lower")
ax[2].set_title("Residual")

# Label each source at its center; scarlet stores centers as (y, x)
for k, src in enumerate(sources):
    y, x = src.center
    ax[0].text(x, y, k, color="w")
    ax[1].text(x, y, k, color="w")
    ax[2].text(x, y, k, color="w")
plt.show()

## Exercises

* Experiment by running the above code using different source models (for example `MultiComponentSource` or `PointSource`) to see how initialization affects the deblending results.

* Change the value of `e_rel` in the above fit and try to understand how it affects the results. 