# Multi-Resolution Modeling

This tutorial shows how to model sources from images observed with different telescopes. We will use a multiband observation with the Hyper Suprime-Cam (HSC) and a single high-resolution image from the Hubble Space Telescope (HST).

In [1]:
# Import Packages and setup
import numpy as np
import scarlet
import scarlet.display
import astropy.io.fits as fits
from astropy.wcs import WCS
from scarlet.display import AsinhMapping
from scarlet import Starlet
import scipy.stats as scs

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# Image defaults: 'gist_stern' colormap, and no interpolation between pixels
matplotlib.rc('image', cmap='gist_stern', interpolation='none')

ImportError: cannot import name 'Starlet' from 'scarlet' (/Users/remy/Desktop/LSST_Project/scarlet/scarlet/__init__.py)

## Load and Display Data

We first load the HSC and HST images, swapping the byte order if necessary because a bug in astropy does not respect the local endianness.

In [None]:
# Load the HSC image cube and its WCS; the byte order of the pixel data is swapped
obs_hdu = fits.open('/Users/remy/Desktop/LSST_Project/scarlet/data/test_resampling/Cut_HSC.fits')
data_hsc = obs_hdu[0].data.byteswap().newbyteorder()
wcs_hsc = WCS(obs_hdu[0].header)
channels_hsc = ['g','r','i','z','y']

# Load the HSC PSF data and wrap it in a scarlet PSF object
psf_hsc = fits.open('/Users/remy/Desktop/LSST_Project/scarlet/data/test_resampling/PSF_HSC.fits')[0].data
# Shape of the first PSF image (not used again in the cells shown here)
Np1, Np2 = psf_hsc[0].shape
psf_hsc = scarlet.PSF(psf_hsc)

# Load the HST image and its WCS (single channel)
hst_hdu = fits.open('/Users/remy/Desktop/LSST_Project/scarlet/data/test_resampling/Cut_HST.fits')
data_hst = hst_hdu[0].data
wcs_hst = WCS(hst_hdu[0].header)
channels_hst = ['F814W']

# Apply a WCS correction: shift the reference values (crval) of both WCSs by fixed offsets
# NOTE(review): presumably an astrometric alignment between the two images - confirm the origin of the numbers
# The two commented-out lines below are alternative offsets that are currently disabled
#wcs_hsc.wcs.crval -= np.array([0,2.22364861e-05, 7.9102005e-06])
#wcs_hst.wcs.crval -= np.array([1.49868145e-06, 1.25988294e-06])
wcs_hst.wcs.crval -= np.array([5.62908748e-05, -3.96986249e-06])
wcs_hsc.wcs.crval -= np.array([0,3.70488632e-06, 4.3840208e-06])
# Load the HST PSF data, add a leading axis of length 1 and wrap it in a scarlet PSF object
psf_hst = fits.open('/Users/remy/Desktop/LSST_Project/scarlet/data/test_resampling/PSF_HST.fits')[0].data
psf_hst = psf_hst[None,:,:]
psf_hst = scarlet.PSF(psf_hst)

# Reshape the HST image to (1, n1, n2) and swap its byte order (no flux scaling is applied here)
n1,n2 = np.shape(data_hst)
data_hst = data_hst.reshape(1, n1, n2).byteswap().newbyteorder()

# Dimensions of the HSC cube (not used again in the cells shown here)
r, N1, N2 = data_hsc.shape

Next we have to create a source catalog for the images. We'll use `sep` for that, but any other detection method will do. Since HST is higher resolution and less affected by blending, we use it for detection but we also run detection on the HSC image to calculate the background RMS:

In [None]:
import sep

class Data:
    """Container grouping one data set's images, WCS, PSF image and channels."""

    def __init__(self, images, wcss, psfs, channels):
        """Store the inputs as attributes.

        Parameters
        ----------
        images : image array of the data set, stored as ``self.images``.
        wcss : WCS of the images, stored as ``self.wcs``.
        psfs : object exposing an ``image`` attribute; only that attribute
            is kept, as ``self.psfs``.
        channels : channel names, stored as ``self.channels``.
        """
        # Only the PSF image is retained, not the PSF object itself
        psf_image = psfs.image
        self.images, self.wcs = images, wcss
        self.psfs, self.channels = psf_image, channels

def interpolate(data_lr, data_hr):
    """Interpolate every low-resolution band onto the high-resolution grid.

    Parameters
    ----------
    data_lr, data_hr : objects with ``images`` and ``wcs`` attributes
        (low- and high-resolution data respectively).

    Returns
    -------
    numpy array stacking one interpolated (and transposed) image per band
    of ``data_lr.images``.

    NOTE(review): both coordinate axes are built from ``shape[1]``, so this
    looks like it assumes square images - confirm before using it on
    non-square cutouts.
    """
    n_lr = data_lr.images.shape[1]
    n_hr = data_hr.images.shape[1]
    grid_lr = (np.arange(n_lr), np.arange(n_lr))
    grid_hr = (np.arange(n_hr), np.arange(n_hr))
    # Express the low-resolution pixel grid in the high-resolution WCS
    grid_lr_in_hr = scarlet.resampling.convert_coordinates(grid_lr, data_lr.wcs, data_hr.wcs)

    def _resample(band):
        # sinc_interp is given a leading axis; its first output plane is transposed
        return scarlet.interpolation.sinc_interp(band[None, :, :], grid_hr, grid_lr_in_hr, angle=None)[0].T

    return np.array([_resample(band) for band in data_lr.images])
        
def makeCatalog_multi(data_lr, data_hr, lvl=3, wave=True):
    """Detect sources on a joint low-/high-resolution image and estimate noise.

    The low-resolution bands are interpolated onto the high-resolution grid,
    every band of both data sets is normalised to unit total flux, and all
    bands are summed into a single detection image. The detection image is
    displayed with matplotlib as a side effect.

    Parameters
    ----------
    data_lr, data_hr : objects with ``images`` and ``wcs`` attributes
        (low- and high-resolution data respectively).
    lvl : detection threshold handed to ``sep.extract``, used together with
        the global RMS of the ``sep`` background as the error.
    wave : if truthy, detection runs on the sum of the first three starlet
        planes of the detection image instead of the image itself.

    Returns
    -------
    catalog : output of ``sep.extract``.
    bg_rms : array with one noise entry per data set, in the order
        ``[data_lr, data_hr]``. Each entry holds one value per band, so the
        entries can have different lengths.
    detect_image : the combined detection image on the high-resolution grid.
    """
    # Interpolate low resolution to high resolution
    interp = interpolate(data_lr, data_hr)
    # Normalise every band of both data sets to unit total flux
    interp = interp / np.sum(interp, axis=(1, 2))[:, None, None]
    hr_images = data_hr.images / np.sum(data_hr.images, axis=(1, 2))[:, None, None]
    # Detection image: sum over all bands, rescaled by the total high-res flux
    detect_image = np.sum(interp, axis=0) + np.sum(hr_images, axis=0)
    detect_image *= np.sum(data_hr.images)
    plt.imshow(detect_image)
    plt.show()

    # Collapse a 3D stack to a single plane; a 2D image is used as is
    if np.size(detect_image.shape) == 3:
        base = detect_image.mean(axis=0)
    else:
        base = detect_image
    if wave:
        # Wavelet detection in the first three levels
        coeffs = Starlet(base).starlet
        detect = coeffs[0][0] + coeffs[0][1] + coeffs[0][2]
    else:
        # Direct detection
        detect = base

    bkg = sep.Background(detect)
    catalog = sep.extract(detect, lvl, err=bkg.globalrms)

    def _starlet_mad(plane):
        # Median absolute deviation of the first starlet plane of one image
        return scs.median_absolute_deviation(Starlet(plane).starlet[0][0].flatten())

    # Use the function's own arguments here: the previous version iterated a
    # notebook-level global `datas`, so the result silently depended on state
    # defined outside the function (NameError if it was missing).
    bg_rms = []
    for data in (data_lr, data_hr):
        img = data.images
        if np.size(img.shape) == 3:
            bg_rms.append(np.array([_starlet_mad(band) for band in img]))
        else:
            bg_rms.append(_starlet_mad(img))

    return catalog, np.array(bg_rms), detect_image

def makeCatalog(data, lvl = 3, wave = True):
    """Detect sources on a single multi-band image cube and estimate its noise.

    Each band is normalised to unit total flux and the bands are summed into
    a detection image, which is displayed with matplotlib as a side effect.

    Parameters
    ----------
    data : image cube; the normalisation sums over axes 1 and 2.
    lvl : detection threshold handed to ``sep.extract``.
    wave : if truthy, detect on the sum of the first three starlet planes
        of the detection image rather than on the image itself.

    Returns
    -------
    catalog : output of ``sep.extract``.
    bg_rms : array wrapping a single entry with the noise estimate(s) of
        ``data`` (one value per band for a 3D cube).
    detect_image : the detection image.
    """
    # Give every band unit total flux, then stack the bands
    band_flux = np.sum(data, axis=(1, 2))
    detect_image = np.sum(data / band_flux[:, None, None], axis=0)

    plt.imshow(detect_image)
    plt.show()

    # A 3D stack is averaged down to one plane; a 2D image is used unchanged
    if np.size(detect_image.shape) == 3:
        base = detect_image.mean(axis=0)
    else:
        base = detect_image
    if wave:
        # Wavelet detection in the first three levels
        coeffs = Starlet(base).starlet
        detect = coeffs[0][0] + coeffs[0][1] + coeffs[0][2]
    else:
        # Direct detection
        detect = base

    background = sep.Background(detect)
    catalog = sep.extract(detect, lvl, err=background.globalrms)

    def _starlet_mad(plane):
        # Median absolute deviation of the first starlet plane of one image
        return scs.median_absolute_deviation(Starlet(plane).starlet[0][0].flatten())

    if np.size(data.shape) == 3:
        noise = np.array([_starlet_mad(band) for band in data])
    else:
        noise = _starlet_mad(data)

    return catalog, np.array([noise]), detect_image

In [None]:
# Bundle each data set (images, WCS, PSF, channels) for the detection helpers
data_hr =  Data(data_hst, wcs_hst, psf_hst, channels_hst)
data_lr =  Data(data_hsc, wcs_hsc, psf_hsc, channels_hsc)
# Order is [low-res (HSC), high-res (HST)], the same order as the entries of bg_rms_multi used below
datas = [data_lr, data_hr]

# wave: truthy value enables wavelet-based detection; lvl: detection threshold passed to sep.extract
wave = 1
lvl = 3
# Joint detection on the combined HSC+HST detection image
catalog_multi, bg_rms_multi, detect_multi = makeCatalog_multi(data_lr, data_hr, lvl, wave)

# Independent detections on each data set alone
catalog_hsc, bg_rms_hsc, detect_hsc = makeCatalog(data_hsc, lvl, wave)
catalog_hst, bg_rms_hst, detect_hst = makeCatalog(data_hst, lvl, wave)

# Weights: 1 / rms**2 per band, broadcast over all pixels (index 1 = HST, index 0 = HSC)
weights_hst = np.ones_like(data_hst) / (bg_rms_multi[1]**2)[:, None, None]
weights_hsc = np.ones_like(data_hsc) / (bg_rms_multi[0]**2)[:, None, None]

Finally we can visualize both the multiband HSC and single band HST images in their native resolutions:

In [None]:
# Create asinh color mappings, one for the HSC images and one for the HST image
hsc_norm = AsinhMapping(minimum=-1, stretch=2, Q=10)
hst_norm = AsinhMapping(minimum=-1, stretch=10, Q=5)

# Get the source pixel coordinates from the three catalogs:
# t = HST-only detection, m = joint (multi) detection, c = HSC-only detection
xt,yt = catalog_hst['x'], catalog_hst['y']
xm,ym = catalog_multi['x'], catalog_multi['y']
xc,yc = catalog_hsc['x'], catalog_hsc['y']
# Convert pixel positions to world coordinates with the WCS of the image they were detected on
# (the joint catalog uses the HST WCS; the HSC WCS calls carry an extra third coordinate, set to 0)
# NOTE(review): pixel arguments are passed in (y, x) order - confirm this matches the WCS axis convention
rat, dect = wcs_hst.wcs_pix2world(yt,xt,0)
ram, decm = wcs_hst.wcs_pix2world(ym,xm,0)
rac, decc, _ = wcs_hsc.wcs_pix2world(yc,xc,0, 0)

# Convert the world coordinates into pixel positions of the *other* image for overplotting
# (`l` receives the unused third HSC coordinate)
Yt,Xt, l = wcs_hsc.wcs_world2pix(rat, dect, 0, 0)
Ym,Xm, l = wcs_hsc.wcs_world2pix(ram, decm, 0, 0)
Yc,Xc = wcs_hst.wcs_world2pix(rac, decc, 0)
# Map the HSC image to RGB
img_rgb = scarlet.display.img_to_rgb(data_hsc, norm=hsc_norm)
# Apply Asinh to the HST data
hst_img = scarlet.display.img_to_rgb(data_hst, norm=hst_norm)

# Left panel: HSC image with HSC (squares), HST (green circles) and joint (red crosses) detections
plt.figure(figsize=(15,30))
plt.subplot(121)
plt.imshow(img_rgb)
plt.plot(xc,yc, 's')
plt.plot(Xt,Yt, 'og')
plt.plot(Xm,Ym, 'xr')

# Right panel: HST image with the same three catalogs in HST pixel coordinates
plt.subplot(122)
plt.imshow(hst_img)
plt.plot(Xc,Yc, 's')
plt.plot(xt,yt, 'og')
plt.plot(xm,ym, 'xr')
plt.show()

## Create Frame and Observations

Unlike the single resolution examples, we now have two different instruments with different pixel resolutions, so we need two different observations. Since the HST image is at a much higher resolution, we define our model `Frame` to use the HST PSF and the HST resolution. Because there is no resampling between the model frame and the HST observation, we can use the default `Observation` class for the HST data. The HSC images have lower resolution, so we need to resample the models to this frame, and that's done by `LowResObservation`.

Alternatively, it is possible to define a frame automatically from a set of observations. In this case, the user does not have to know which observation needs to be a `LowResObservation`. Instead, the method `Frame.from_observations` creates a frame that encapsulates either the union or the intersection of a set of observations, and defines that frame based on the highest resolution available among all observations.

In [None]:
# Automated frame definition
# Define one observation per instrument; they are matched to a common frame below
multi_hst = scarlet.Observation(data_hst, wcs=wcs_hst, psfs=psf_hst, channels=channels_hst, weights=weights_hst)
multi_hsc = scarlet.Observation(data_hsc, wcs=wcs_hsc, psfs=psf_hsc, channels=channels_hsc, weights=weights_hsc)

# Keep the order of the observations consistent with the `channels` parameter
# This implementation is a bit of a hack and will be refined in the future
obs = [multi_hsc, multi_hst]
# Build the model frame from the region covered by both observations
frame = scarlet.Frame.from_observations(obs, coverage = 'intersection')
# Re-read the observations from the list
# NOTE(review): presumably from_observations replaces the list entries in place - confirm
multi_hsc, multi_hst = obs

# Same construction with the joint detection image standing in for the HST data;
# the leading axis of length 1 makes it a single-channel cube
obs_detect = scarlet.Observation(detect_multi[None,:,:], wcs=wcs_hst, psfs=psf_hst, channels=channels_hst)
obs_detects = [multi_hsc, obs_detect]
frame_detect = scarlet.Frame.from_observations(obs_detects, coverage = 'intersection')
obs_detect = obs_detects[1]

In this experiment, we are going to compare the fit to the data using HSC data alone, HST data alone, and the multi-resolution framework applied to the combination of HST and HSC. We need to build a frame for each of these cases:

In [None]:
# Separate observations for the single-instrument fits (same inputs as the multi-resolution ones)
obs_hst = scarlet.Observation(data_hst, wcs=wcs_hst, psfs=psf_hst, channels=channels_hst, weights=weights_hst)
obs_hsc = scarlet.Observation(data_hsc, wcs=wcs_hsc, psfs=psf_hsc, channels=channels_hsc, weights=weights_hsc)

# HSC-only model frame: built from the HSC data shape, WCS, PSF and channels
HSC_frame = scarlet.Frame(
    data_hsc.shape,
    wcs = wcs_hsc,
    psfs=psf_hsc,
    channels=channels_hsc)
obs_hsc.match(HSC_frame)

# HST-only model frame: built from the HST data shape, WCS, PSF and channels
HST_frame = scarlet.Frame(
    data_hst.shape,
    wcs = wcs_hst,
    psfs=psf_hst,
    channels=channels_hst)
obs_hst.match(HST_frame)


## Initialize Sources and Blend

We expect all sources to be galaxies, so we initialize them as `ExtendedSource` objects. Because the initialization takes a list of observations, we set the `obs_idx` argument to state which observation in the list of observations is used to initialize the morphology.

`Blend` will hold a list of all sources and *all* observations to fit.

In [None]:
# Multi-resolution blend: one extended source per joint-catalog position,
# initialised from the observation at index 1 of `obs_detects` (obs_idx=1)
multi_sources = [
    scarlet.ExtendedSource(frame_detect, (ra, dec), obs_detects,
                           symmetric=False,
                           monotonic=True,
                           obs_idx=1)
    for ra, dec in zip(ram, decm)
]

# The blend is fitted against the real HSC and HST observations
blend_multi = scarlet.Blend(multi_sources, obs)

In [None]:
# HSC-only blend: one extended source per position of the HSC catalog
HSC_sources = [
    scarlet.ExtendedSource(HSC_frame, (ra, dec), obs_hsc,
                           symmetric=False,
                           monotonic=True)
    for ra, dec in zip(rac, decc)
]

blend_hsc = scarlet.Blend(HSC_sources, obs_hsc)

In [None]:
# HST-only blend: one extended source per position of the HST catalog
HST_sources = [
    scarlet.ExtendedSource(HST_frame, (ra, dec), obs_hst,
                           symmetric=False,
                           monotonic=True)
    for ra, dec in zip(rat, dect)
]

blend_hst = scarlet.Blend(HST_sources, obs_hst)

In [None]:
# HSC-only blend seeded with the joint (multi) catalog positions
# (rebinds HSC_sources, replacing the list built from the HSC catalog)
HSC_sources = [
    scarlet.ExtendedSource(HSC_frame, (ra, dec), obs_hsc,
                           symmetric=False,
                           monotonic=True)
    for ra, dec in zip(ram, decm)
]

blend_multihsc = scarlet.Blend(HSC_sources, obs_hsc)

In [None]:
# HST-only blend seeded with the joint (multi) catalog positions
# (rebinds HST_sources, replacing the list built from the HST catalog)
HST_sources = [
    scarlet.ExtendedSource(HST_frame, (ra, dec), obs_hst,
                           symmetric=False,
                           monotonic=True)
    for ra, dec in zip(ram, decm)
]

blend_multihst = scarlet.Blend(HST_sources, obs_hst)

## Display Initial guess

Let's compare the initial guess of the model in both model frame and HSC observation frame:

In [None]:
def display_init(blend, obs, data, ids, img, title=None):
    """Show the data, the rendered blend model and the residual side by side.

    Parameters
    ----------
    blend : blend whose current model is displayed (``blend.get_model()``).
    obs : observation used to render the model (``obs.render``).
    data : image cube the rendered model is subtracted from.
    ids : figure number passed to ``plt.figure``.
    img : displayable image of the data, shown in the first panel.
    title : optional figure title; no super-title is drawn when omitted.
        (It used to be required, which made five-argument calls fail.)

    The model is colour-mapped with the notebook-level ``hsc_norm``.
    """
    # Render the current model into the observation and form the residual
    model = obs.render(blend.get_model())
    residual = data - model

    model_rgb = scarlet.display.img_to_rgb(model, norm=hsc_norm)
    # The last five rows are dropped from the residual before display
    # NOTE(review): presumably to cut a source that is not part of the blend - confirm
    residual_rgb = scarlet.display.img_to_rgb(residual[:,:-5])

    plt.figure(ids, figsize=(15, 10))
    if title is not None:
        plt.suptitle(title, fontsize=36)
    # Panel titles name what is drawn; the function is used for HSC and HST alike
    plt.subplot(131)
    plt.imshow(img)
    plt.title("Data")
    plt.subplot(132)
    plt.imshow(model_rgb)
    plt.title("Model")
    plt.subplot(133)
    plt.imshow(residual_rgb)
    plt.title("Residual")

# Initial guesses of the four single-instrument blends, one figure each (figure numbers 0-3)
display_init(blend_hsc, obs_hsc, data_hsc, 0, img_rgb, 'HSC initialisation')
display_init(blend_hst, obs_hst, data_hst, 1, data_hst[0], 'HST initialisation')
display_init(blend_multihsc, obs_hsc, data_hsc, 2, img_rgb, 'HSC-multi initialisation')
display_init(blend_multihst, obs_hst, data_hst, 3, data_hst[0], 'HST-multi initialisation')

In [None]:
# Load the initial multi-resolution model and calculate the residuals
model_multi = blend_multi.get_model()

# Render the model into the HSC observation and build the RGB images
model_lr = multi_hsc.render(model_multi)
init_rgb = scarlet.display.img_to_rgb(model_multi, norm=hsc_norm)
init_rgb_lr = scarlet.display.img_to_rgb(model_lr, norm=hsc_norm)
residual_lr = data_hsc - model_lr
# Trim the bottom source not part of the blend from the image
residual_lr_rgb = scarlet.display.img_to_rgb(residual_lr[:,:-5])

# Get the HR residual
residual_hr = (data_hst - multi_hst.render(model_multi))[0]
# Symmetric colour limits: use the largest magnitude so that negative
# residuals are not clipped and the range stays valid when max() <= 0
vmax = np.abs(residual_hr).max()

plt.figure(figsize=(15, 10))

# Top row: HSC data, HSC-frame model, HSC residual
# Bottom row: HST data, model-frame image, HST residual
plt.subplot(231)
plt.suptitle('Multi-resolution initialisation', fontsize=36)
plt.imshow(img_rgb)
plt.title("HSC data")
plt.subplot(235)
plt.imshow(init_rgb)
plt.title("HighRes Model")
plt.subplot(232)
plt.imshow(init_rgb_lr)
plt.title("LowRes Model")
plt.subplot(236)
plt.imshow(residual_hr, cmap="seismic", vmin=-vmax, vmax=vmax)
plt.colorbar(fraction=.045)
plt.title("HST residual")
plt.subplot(233)
plt.imshow(residual_lr_rgb)
plt.title("HSC residual")
plt.subplot(234)
plt.imshow(hst_img)
plt.colorbar(fraction=.045)
plt.title('HST data')
plt.show()

## Fit Models

In [None]:
# Fit the HST-only blend (HST catalog); 100 is presumably the iteration cap, e_rel the relative tolerance
blend_hst.fit(100, e_rel = 1.e-7)
# One loss entry per iteration; the negated loss is reported and plotted as the log-likelihood
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend_hst.loss), -blend_hst.loss[-1]))
plt.plot(-np.array(blend_hst.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
# Fit the HSC-only blend (HSC catalog); 100 is presumably the iteration cap, e_rel the relative tolerance
blend_hsc.fit(100, e_rel = 1.e-7)
# One loss entry per iteration; the negated loss is reported and plotted as the log-likelihood
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend_hsc.loss), -blend_hsc.loss[-1]))
plt.plot(-np.array(blend_hsc.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
# Fit the HST-only blend seeded with the joint catalog; 100 is presumably the iteration cap
blend_multihst.fit(100, e_rel = 1.e-7)
# One loss entry per iteration; the negated loss is reported and plotted as the log-likelihood
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend_multihst.loss), -blend_multihst.loss[-1]))
plt.plot(-np.array(blend_multihst.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
# Fit the HSC-only blend seeded with the joint catalog; 100 is presumably the iteration cap
blend_multihsc.fit(100, e_rel = 1.e-7)
# One loss entry per iteration; the negated loss is reported and plotted as the log-likelihood
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend_multihsc.loss), -blend_multihsc.loss[-1]))
plt.plot(-np.array(blend_multihsc.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
# Fit the multi-resolution blend (300 instead of 100 as first argument) and time the call
%time blend_multi.fit(300, e_rel = 1.e-7)
# One loss entry per iteration; the negated loss is reported and plotted as the log-likelihood
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend_multi.loss), -blend_multi.loss[-1]))
plt.plot(-np.array(blend_multi.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

### View Full Model
First we load the model for the entire blend and its residual. Then we display the model using the same $\sinh^{-1}$ stretch as the full image and a linear stretch for the residual to see the improvement from our initial guess.

In [None]:
# Same panels as for the initial guess, now showing the fitted models.
# A title is passed for every figure (the calls previously omitted it).
display_init(blend_hsc, obs_hsc, data_hsc, 0, img_rgb, 'HSC fit')
display_init(blend_hst, obs_hst, data_hst, 1, data_hst[0], 'HST fit')
display_init(blend_multihsc, obs_hsc, data_hsc, 2, img_rgb, 'HSC-multi fit')
display_init(blend_multihst, obs_hst, data_hst, 3, data_hst[0], 'HST-multi fit')

In [None]:
# Fitted multi-resolution model, rendered into both observations
model_multi = blend_multi.get_model()
model_hr = multi_hst.render(model_multi)
model_lr = multi_hsc.render(model_multi)

# The last channel of the model is left out of the model-frame RGB image
rgb = scarlet.display.img_to_rgb(model_multi[:-1], norm=hsc_norm)
rgb_lr = scarlet.display.img_to_rgb(model_lr, norm=hsc_norm)
residual_lr = data_hsc - model_lr

# Trim the bottom source not part of the blend from the image
residual_lr_rgb = scarlet.display.img_to_rgb(residual_lr[:,:-5])

# Get the HR residual
residual_hr = (data_hst - model_hr)[0]
# Symmetric colour limits: use the largest magnitude so that negative
# residuals are not clipped and the range stays valid when max() <= 0
vmax = np.abs(residual_hr).max()

# Top row: HSC data, HSC model, HSC residual
# Bottom row: HST data, model-frame image, HST residual
plt.figure(figsize=(15, 10))
plt.subplot(231)
plt.imshow(img_rgb)
plt.title("HSC data")
plt.subplot(235)
plt.imshow(rgb)
plt.title("HST Model")
plt.subplot(232)
plt.imshow(rgb_lr)
plt.title("HSC Model")
plt.subplot(236)
plt.imshow(residual_hr, cmap="seismic", vmin=-vmax, vmax=vmax)
plt.colorbar(fraction=.045)
plt.title("HST residual")
plt.subplot(233)
plt.imshow(residual_lr_rgb)
plt.title("HSC residual")
plt.subplot(234)
plt.imshow(hst_img)
plt.title('HST data')
plt.show()

### View Source Models
It can also be useful to view the model for each source. For each source we extract the portion of the image contained in the source's bounding box, the true simulated source flux, and the model of the source, scaled so that all of the images have roughly the same pixel scale.

In [None]:
# NOTE(review): has_truth and axes are not used anywhere in this cell - confirm they can go
has_truth = False
axes = 2

# One figure per source of the multi-resolution blend; the single-instrument blends
# seeded with the same joint catalog are indexed with the same k
for k,src in enumerate(blend_multi.sources):
    print('source number ', k)
    # Get the model for a single source
    model_multi = src.get_model()
    # NOTE(review): these index the blend directly while the loop uses `.sources` - confirm both are supported
    model_multihsc = blend_multihsc[k].get_model()
    model_multihst = blend_multihst[k].get_model()
    
    # Render the multi-resolution source model into each observation
    model_lr = multi_hsc.render(model_multi)
    model_hr = multi_hst.render(model_multi)

    # Render the single-instrument source models into their own observation
    model_multihsc = obs_hsc.render(model_multihsc)
    model_multihst = obs_hst.render(model_multihst)
    
    # Display the low resolution image and residuals
    # NOTE(review): img_lr_rgb and res_lr are computed but never plotted below
    img_lr_rgb = scarlet.display.img_to_rgb(model_lr, norm = hsc_norm)
    res_lr = data_hsc-model_lr
    
    plt.figure(figsize=(40,10))
    
    # Panels: HSC data with the source position marked, then the last channel of each rendered model
    plt.subplot(141)
    plt.imshow(img_rgb)
    plt.plot(Xm[k],Ym[k], 'o', markersize = 10)
    plt.title("HSC Data", fontsize = 30)
    plt.subplot(142)
    plt.imshow(model_hr[-1])
    plt.title("Multi resolution model", fontsize = 30)
    plt.subplot(143)
    plt.imshow(model_multihsc[-1])
    plt.title("HSC Model", fontsize = 30)
    plt.subplot(144)
    plt.imshow(model_multihst[-1])
    plt.title("HST Model", fontsize = 30)
    plt.show()