In [None]:
import scarlet
import galsim
import astropy.io.fits as fits
from astropy.wcs import WCS
import time
import proxmin
import pickle
from scarlet import Box


# Import Packages and setup
import numpy as np

import scarlet.display
from scarlet.display import AsinhMapping
from scarlet import Starlet
import scipy.stats as scs
from scarlet.initialization import build_initialization_image, set_spectra_to_match
from scarlet.renderer import ConvolutionRenderer
from functools import partial
from scarlet_extensions.initialization.detection import makeCatalog, Data
from scarlet_extensions.scripts.runner import Runner
from scarlet_extensions.initialization.detection import mad_wavelet
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='gray')
matplotlib.rc('image', interpolation='none')

In [None]:
class FullMorphology(scarlet.ImageMorphology):
    """Free-form pixelated morphology, constrained only to be positive.

    No monotonicity or symmetry constraint is attached to the image parameter.
    """

    def __init__(
        self,
        frame,
        center,
        image,
        bbox=None,
        min_grad=0,
        shifting=False,
        resizing=True,
    ):
        """Non-parametric image morphology designed for galaxies as extended sources.

        Parameters
        ----------
        frame: `~scarlet.Frame`
            The frame of the full model
        center: array-like of two floats
            Center of the source in model-frame pixel coordinates.
            NOTE(review): presumably (y, x) like the rest of scarlet -- confirm.
        image: `numpy.ndarray`
            Initial image of the source.
        bbox: `~scarlet.Box`
            2D bounding box for the location of `image` in `frame`
        min_grad: float
            Accepted for signature compatibility with other scarlet
            morphologies; it is not used by this class.
        shifting: `bool`
            Whether or not a subpixel shift is added as optimization parameter
        resizing: `bool`
            Whether to resize the box dynamically
        """
        # Positivity is the only constraint placed on the free-form image.
        image = scarlet.Parameter(
            image, name="image", step=1e-2, constraint=scarlet.PositivityConstraint()
        )

        center = np.asarray(center, dtype=float)
        self.pixel_center = np.round(center).astype("int")
        # The optional shift parameter holds the sub-pixel residual of the center.
        if shifting:
            shift = scarlet.Parameter(center - self.pixel_center, name="shift", step=1e-1)
        else:
            shift = None
        self.shift = shift

        super().__init__(
            frame, image, bbox=bbox, shifting=shifting, shift=shift, resizing=resizing
        )

    @property
    def center(self):
        """Integer pixel center, plus the optimized sub-pixel shift when enabled."""
        if self.shift is not None:
            return self.pixel_center + self.shift
        return self.pixel_center


class FullSource(scarlet.FactorizedComponent):
    """Extended source whose morphology is seeded from one observation's image.

    The morphology starts as the first band of a template observation pasted
    into an otherwise empty model-frame image, with faint pixels set to zero.
    The spectrum is tabulated from `spectra`.
    """

    def __init__(
        self,
        model_frame,
        sky_coord,
        observations,
        thresh=1.0,
        shifting=False,
        resizing=True,
        spectra=None,
        template_index=1,
        insert_offset=16,
    ):
        """Extended source model.

        During optimization the morphology is only constrained to be positive
        (see `FullMorphology`).

        Parameters
        ----------
        model_frame: `~scarlet.Frame`
            The frame of the model
        sky_coord: tuple
            Sky position of the source. NOTE(review): currently unused -- the
            morphology is always centered on the middle of `model_frame`.
        observations: instance or list of `~scarlet.observation.Observation`
            Observation(s) used to initialize this source.
        thresh: `float`
            Currently unused; kept for signature compatibility.
        shifting: `bool`
            Whether or not a subpixel shift is added as optimization parameter
        resizing: `bool`
            Whether or not to change the size of the source box.
        spectra: array-like or None
            Spectrum passed to `~scarlet.TabulatedSpectrum`. If None, a random
            one-hot vector is drawn for each observation and the vectors are
            concatenated in observation order (assumes the model-frame
            channels are the observations' channels in that order -- TODO
            confirm).
        template_index: int
            Index into `observations` of the observation whose first band
            seeds the morphology (default 1, the previously hard-coded value).
        insert_offset: int or tuple
            Offset added to the template observation's bounding box to place
            its image in the model frame (default 16, previously hard-coded).
            NOTE(review): presumably the padding of the union model frame
            around the template observation -- confirm.
        """
        import copy

        if not hasattr(observations, "__iter__"):
            observations = (observations,)

        if spectra is None:
            # One random one-hot vector per observation, joined into a single
            # vector so it has one entry per (concatenated) channel.
            spectra = np.concatenate(
                [np.eye(obs.C)[np.random.randint(obs.C)] for obs in observations]
            )

        # Paste the template observation's first band into a model-frame image.
        template_obs = observations[template_index]
        insert_box = copy.copy(template_obs.bbox)
        insert_box += insert_offset
        morph = np.zeros(model_frame.shape[1:])
        morph[insert_box.slices[1:]] = template_obs.data[0]
        # Zero every pixel below 5x the wavelet MAD noise estimate.
        morph[morph < 5 * mad_wavelet(morph[None, :])] = 0
        bbox = Box(morph.shape)

        center = (model_frame.shape[-2] / 2, model_frame.shape[-1] / 2)
        morphology = FullMorphology(
            model_frame,
            center,
            morph,
            bbox=bbox,
            min_grad=0,
            shifting=shifting,
            resizing=resizing,
        )

        # Mean noise RMS of every band of every observation, used as the
        # minimum optimizer step of the spectrum.
        noise_rms = np.concatenate(
            [np.array(np.mean(obs.noise_rms, axis=(1, 2))) for obs in observations]
        ).reshape(-1)
        spectrum = scarlet.TabulatedSpectrum(model_frame, spectra, min_step=noise_rms)

        super().__init__(model_frame, spectrum, morphology)

        # Retain the center as an attribute.
        self.center = morphology.center

    

In [None]:

path = './Multi-resolution/COSMOS0211/'
psf_path = './Multi-resolution/psf_0211/'


def _read_fits(filename, hdu=0):
    """Read one HDU of a FITS file, closing the file before returning.

    Returns
    -------
    data : `numpy.ndarray`
        The HDU data converted to native byte order. `astype` always copies,
        so the array does not depend on the closed file. It also works on
        NumPy 2, which removed `ndarray.newbyteorder`.
    header : `astropy.io.fits.Header`
        A copy of the HDU header.
    """
    with fits.open(filename) as hdul:
        data = hdul[hdu].data
        data = data.astype(data.dtype.newbyteorder('='))
        header = hdul[hdu].header.copy()
    return data, header


# HST F814W image (one band)
im_hst, hdr_hst = _read_fits(path + 'cut_0001_150.54667000_2.19417000_acs_I_mosaic_30mas_sci.fits')
wcs_hst = WCS(hdr_hst)
channels_hst = ['f814w']

# HSC multi-band cube
im_hsc, hdr_hsc = _read_fits(path + 'cube_HSC_8-cutout-HSC-Y-9813-pdr2_dud.fits')
wcs_hsc = WCS(hdr_hsc)

# HST PSF, given a leading channel axis to match the single-band image
psf_hst_image = _read_fits(path + 'PSF_hst_0001_150.54667000_2.19417000_acs_I_mosaic_30mas_sci.fits')[0]
psf_hst = scarlet.ImagePSF(psf_hst_image[None, :, :])

# HSC PSF cube
psf_hsc_image = _read_fits(psf_path + 'cube_HSC_7-psf-calexp-pdr2_dud-HSC-Z-9813-2,4-150.54667-2.19417.fits')[0]
Np1, Np2 = psf_hsc_image[0].shape
psf_hsc = scarlet.ImagePSF(psf_hsc_image)

channels_hsc = ['G', 'R', 'I', 'Z', 'Y']
channels = channels_hsc + channels_hst

# Give the HST image a leading channel axis: (1, n1, n2)
n1, n2 = np.shape(im_hst)
im_hst = im_hst[None, :, :]

# Detection inputs, ordered [HSC (low-res), HST (high-res)]
data_hst = Data(im_hst, wcs_hst, psf_hst, channels_hst)
data_hsc = Data(im_hsc, wcs_hsc, psf_hsc, channels_hsc)
datas = [data_hsc, data_hst]

In [None]:
# define two observation packages and match to frame
obs_hst = scarlet.Observation(im_hst, 
                              wcs=wcs_hst, 
                              psf=psf_hst, 
                              channels=channels_hst, 
                              weights=None)

obs_hsc = scarlet.Observation(im_hsc, 
                              wcs=wcs_hsc, 
                              psf=psf_hsc, 
                              channels=channels_hsc, 
                              weights=None)

observations = [obs_hsc, obs_hst]
model_psf = scarlet.GaussianPSF(sigma=[[.3, .3]])
model_frame = scarlet.Frame.from_observations(observations, coverage='union', model_psf=model_psf)
obs_hsc, obs_hst = observations

obs_hst.match(model_frame, renderer=ConvolutionRenderer(obs_hst, model_frame, 
                                                         convolution_type="fft", 
                                                         psf_shift=np.array([-2.,-1.])))

In [None]:
wave = 1
lvl = 2
thresh = 3

import sep


def makeThisCatalog(obs_lr, obs_hr, lvl=3, wave=True, thresh=3):
    """Build a joint detection catalog from a low- and a high-resolution observation.

    The low-resolution images are interpolated onto the high-resolution grid.
    Every band of both observations is normalised to unit total flux, and the
    bands are summed into one detection image rescaled to the total flux of
    `obs_hr`. SEP then runs on that image, or on the sum of its first `lvl`
    starlet scales when `wave` is true.

    Parameters
    ----------
    obs_lr, obs_hr : `~scarlet.Observation`
        Low- and high-resolution observations; detections are in `obs_hr` pixels.
    lvl : int
        Number of leading starlet scales summed for detection (used only when
        `wave` is true).
    wave : bool
        Detect on the starlet-filtered image instead of the raw detection image.
    thresh : float
        SEP detection threshold in units of the detection image's global
        background RMS (default 3, the value previously hard-coded).

    Returns
    -------
    catalog : structured `numpy.ndarray`
        SEP detections.
    bg_rms : list
        For `obs_lr` then `obs_hr`: an array of per-band global background RMS
        for 3D data, or a single RMS value for 2D data.
    detect_image : `numpy.ndarray`
        The summed detection image *before* starlet filtering.
    """

    def _sep_ready(array):
        # SEP needs C-contiguous arrays in native byte order (FITS data are
        # big-endian); this is a no-op when the array already qualifies.
        return np.ascontiguousarray(array, dtype=array.dtype.newbyteorder('='))

    # Interpolate the low-resolution images onto the high-resolution grid
    interp = scarlet.interpolation.interpolate_observation(obs_lr, obs_hr)
    # Normalise every band to unit total flux so each band weighs equally
    interp = interp / np.sum(interp, axis=(1, 2))[:, None, None]
    hr_images = obs_hr.data / np.sum(obs_hr.data, axis=(1, 2))[:, None, None]
    detect_image = np.sum(interp, axis=0) + np.sum(hr_images, axis=0)
    # Rescale to the flux level of the high-resolution data
    detect_image *= np.sum(obs_hr.data)

    if wave:
        # Sum the first `lvl` starlet scales
        coefficients = scarlet.Starlet.from_image(detect_image).coefficients
        detect = coefficients[:lvl, :, :].sum(axis=0)
    else:
        detect = detect_image
    detect = _sep_ready(detect)

    bkg = sep.Background(detect)
    catalog = sep.extract(detect, thresh, err=bkg.globalrms)

    bg_rms = []
    for img in (obs_lr.data, obs_hr.data):
        if np.ndim(img) == 3:
            bg_rms.append(np.array([sep.Background(_sep_ready(band)).globalrms for band in img]))
        else:
            bg_rms.append(sep.Background(_sep_ready(img)).globalrms)

    return catalog, bg_rms, detect_image

# Joint HSC+HST catalog (which also returns the per-band background RMS),
# plus single-survey catalogs for comparison. With `wave` on, only the first
# `lvl` starlet scales are summed into the detection image.
catalog, bg_rms, detect = makeThisCatalog(obs_hsc, obs_hst, lvl, wave)
catalog_hst, _ = makeCatalog([datas[1]], lvl=lvl, thresh=thresh, wave=wave)
catalog_hsc, _ = makeCatalog([datas[0]], lvl=lvl, thresh=thresh, wave=wave)

# Constant inverse-variance weights per band from the global background RMS
# (HSC pairs with datas[0]/bg_rms[0], HST with datas[1]/bg_rms[1]).
for obs, data, rms in zip((obs_hsc, obs_hst), datas, bg_rms):
    obs.weights = np.ones_like(data._images) / (rms ** 2)[:, None, None]

In [None]:
wcs_hst = datas[1].wcs
wcs_hsc = datas[0].wcs
obs_hsc, obs_hst = observations

# Catalog positions: joint (o), HST-only (t), HSC-only (c)
xo, yo = catalog['x'], catalog['y']
xt, yt = catalog_hst['x'], catalog_hst['y']
xc, yc = catalog_hsc['x'], catalog_hsc['y']

hsc_norm = AsinhMapping(minimum=0, stretch=1, Q=10)
hst_norm = AsinhMapping(minimum=0, stretch=0.1, Q=5)

# scarlet expects pixel coordinates as (y, x)
pixel_hst = np.stack((yt, xt), axis=1)
pixel_hsc = np.stack((yc, xc), axis=1)
pixels = np.stack((yo, xo), axis=1)

# Convert each pixel list with the WCS of the image it was measured on
sky_coord_hst = obs_hst.get_sky_coord(pixel_hst)
sky_coord = obs_hsc.get_sky_coord(pixel_hsc)
rac = sky_coord[:, 0]
decc = sky_coord[:, 1]

# HSC detections projected onto the HST grid
XYc = obs_hst.get_pixel(sky_coord)
Xc = XYc[:, 0]
Yc = XYc[:, 1]

fig, (ax_hst, ax_hsc) = plt.subplots(1, 2, figsize=(20, 10))
ax_hst.imshow(np.abs(im_hst[0]) ** 0.2)
ax_hsc.imshow(im_hsc[-1])
ax_hsc.plot(xc, yc, 'bx')
plt.show()


def spectrum_at_hst_pixel(hst_pixel):
    """Return the values of every band of every observation at an HST pixel position.

    Pixel indices are truncated with `astype(int)`.
    """
    radec = obs_hst.get_sky_coord(hst_pixel)
    values = []
    for obs in observations:
        row, col = obs.get_pixel(radec).astype(int)
        values.extend(image[row, col] for image in obs.data)
    return values


spectra = [spectrum_at_hst_pixel(p) for p in pixel_hst]

# NOTE(review): every source is placed at the same fixed HST pixel rather than
# at its own detection; only the initial spectra differ. Confirm this is intended.
SOURCE_PIXEL_HST = [250, 200]
N_SOURCES = 2
radec = obs_hst.get_sky_coord(SOURCE_PIXEL_HST)
sources = []
for spectrum in spectra[:N_SOURCES]:
    sources.append(scarlet.StarletSource(
        model_frame,
        radec,
        observations,
        thresh=0.01,
        spectrum=np.array(spectrum),
        boxsize=500,
    ))

for view_obs, view_norm in ((obs_hst, hst_norm), (obs_hsc, hsc_norm)):
    scarlet.display.show_scene(
        sources,
        norm=view_norm,
        observation=view_obs,
        show_rendered=True,
        show_observed=True,
        show_residual=True,
    )
    plt.show()



In [None]:
blend = scarlet.Blend(sources, observations)

# Fit one iteration at a time and redraw the scene against both
# observations after every step.
for i in range(100):
    print(i)
    blend.fit(1, e_rel=5e-6)
    for view_obs, view_norm in ((obs_hst, hst_norm), (obs_hsc, hsc_norm)):
        scarlet.display.show_scene(
            sources,
            norm=view_norm,
            observation=view_obs,
            show_rendered=True,
            show_observed=True,
            show_residual=True,
        )
        plt.show()




In [None]:
# Continue fitting in chunks of five iterations, redrawing both scenes
# after each chunk.
for i in range(100):
    print(i)
    blend.fit(5, e_rel=5e-6)
    for view_obs, view_norm in ((obs_hst, hst_norm), (obs_hsc, hsc_norm)):
        scarlet.display.show_scene(
            sources,
            norm=view_norm,
            observation=view_obs,
            show_rendered=True,
            show_observed=True,
            show_residual=True,
        )
        plt.show()

In [None]:
# scarlet records the optimizer loss, i.e. the *negative* log-likelihood, so
# negate it to plot the log-likelihood. The first iterations are skipped
# (presumably to hide the initial transient), but the x axis keeps the true
# iteration numbers.
first_iter = 5
loss = np.asarray(blend.loss)
plt.plot(np.arange(first_iter, len(loss)), -loss[first_iter:])
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')
plt.show()

# Sum all source models on the model frame
model_frame = sources[0].frame
model = np.zeros(model_frame.shape)
for src in sources:
    model += src.get_model(frame=model_frame)

# Drop the last channel and flip both spatial axes for display.
# NOTE(review): the last channel is presumably the HST band, yet the remaining
# (HSC) bands are stretched with `hst_norm` -- confirm which norm is intended.
plt.imshow(scarlet.display.img_to_rgb(model[:-1, ::-1, ::-1], norm=hst_norm))
plt.show()



In [None]:

# Per-source views against each observation, each with its own stretch
# (HST with hst_norm, HSC with hsc_norm, as in the scene displays above).
scarlet.display.show_sources(
    sources,
    norm=hst_norm,
    observation=obs_hst,
    show_rendered=True,
    show_observed=True,
    add_boxes=True,
)
plt.show()

scarlet.display.show_sources(
    sources,
    norm=hsc_norm,
    observation=obs_hsc,
    show_rendered=True,
    show_observed=True,
    add_boxes=True,
    show_model=True,
)
plt.show()