# Simulations for multi-resolution deblending

In this notebook I test multi-resolution deblending on simulated images generated with the GalSim package.

In [None]:
import scarlet
import galsim
from astropy import wcs as WCS
import time
from mr_tools import galsim_compare_tools as gct
from mr_tools.simulations import Simulation, load_surveys, chi
import proxmin
import pickle

# Import Packages and setup
import numpy as np
import scarlet.display
from scarlet.display import AsinhMapping
from scarlet import Starlet
from scarlet.wavelet import mad_wavelet
import scipy.stats as scs
from functools import partial
from scarlet_extensions.initialization.detection import makeCatalog, Data
from scarlet_extensions.scripts.runner import Runner
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# Image display defaults: a higher-contrast colormap, and no interpolation
# between pixels when images are shown.
matplotlib.rc('image', cmap='gist_stern', interpolation='none')

In [None]:
%pylab inline
# Setup: survey properties, sky position of the image centre, and the galaxy catalog.

# One object per survey, unpacked in the order load_surveys() returns them.
HST, EUCLID, ROMAN, HSC, RUBIN = load_surveys()
print(RUBIN)

# The RA, Dec of the center of the image on the sky.
center_ra, center_dec = 19.3 * galsim.hours, -33.1 * galsim.degrees

# COSMOS example catalog, read from GalSim's shared data directory.
data_dir = galsim.meta_data.share_dir
cat = galsim.COSMOSCatalog(
    dir=data_dir,
    file_name='real_galaxy_catalog_23.5_example.fits',
)

In [None]:
# Generate simulations
hr_dict = EUCLID
lr_dict = RUBIN

# nlr and nhr are handed to mk_scene as the (n, n) shapes of the low- and
# high-resolution images; nhr rescales nlr by the ratio of the two surveys'
# 'pixel' entries.
nlr = 60
# The builtin int replaces np.int, which was a plain alias of it and has been
# removed from recent NumPy releases.
nhr = int(np.around(nlr*lr_dict['pixel']/hr_dict['pixel'], decimals = 3))
print(nlr, nhr)
# Random integer in [0, 9], drawn as a scalar so no array-to-int conversion is
# needed.
# NOTE(review): ngal is not used below; mk_scene is called with the literal 3
# instead -- confirm whether ngal was meant to be passed.
ngal = int(np.random.rand()*10)

# Set to True to reuse the scene cached in ./pictures.pkl instead of drawing a
# new one. It is off by default, which matches the previous behaviour: the
# load was always bypassed by a reference to the undefined name `skip`.
reuse_cached_scene = False

pics = None
if reuse_cached_scene:
    try:
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were written by this notebook.
        with open("./pictures.pkl", "rb") as f:
            pics = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError):
        # Missing, truncated or stale cache: fall back to a fresh scene.
        pics = None

if pics is None:
    pics = gct.mk_scene(hr_dict,
                        lr_dict,
                        cat,
                        (nhr,nhr),
                        (nlr,nlr),
                        3,
                        gal_type = 'real',
                        pt_fraction = 0,
                        magmin = 20,
                        magmax = 29,
                        shift=True)
    # Cache the new scene; the context manager closes the file even if the
    # dump fails.
    with open("pictures.pkl","wb") as f:
        pickle.dump(pics, f)
# The scene holds two pictures: high resolution first, then low resolution.
pic_hr, pic_lr = pics

# Source positions stored with the high-resolution picture.
shifts = np.array(pic_hr.shifts)

# High-resolution products: WCS, image cube, galaxies and PSFs.
wcs_hr = pic_hr.wcs
hr = pic_hr.cube
gs_hr = pic_hr.galaxies
psf_hr = np.array(pic_hr.psfs)

# Low-resolution counterparts.
wcs_lr = pic_lr.wcs
lr = pic_lr.cube
gs_lr = pic_lr.galaxies
psf_lr = np.array(pic_lr.psfs)


In [None]:
# Channels of each survey
channels_hr = hr_dict['channels']
channels_lr = lr_dict['channels']

# Dimensions of the two image cubes (no rescaling is applied to the data here).
# assumes a (channel, y, x) axis order -- TODO confirm
# The former second unpack of np.shape(hr) was redundant and rebound `_`,
# which IPython uses for the last output, so it is dropped.
n, n1, n2 = np.shape(hr)
r, N1, N2 = lr.shape

In [None]:
# Inputs for source detection and fitting

# Gaussian model PSFs: the same width (0.8) in every channel, on a 9-pixel box.
model_psf_hr = scarlet.GaussianPSF(sigma=tuple([0.8] * len(channels_hr)), boxsize=9)
model_psf_lr = scarlet.GaussianPSF(sigma=tuple([0.8] * len(channels_lr)), boxsize=9)

# Bundle each image cube with its WCS, its image PSF and its channels.
data_hr = Data(hr, wcs_hr, scarlet.ImagePSF(psf_hr), channels_hr)
data_lr = Data(lr, wcs_lr, scarlet.ImagePSF(psf_lr), channels_lr)

# Low resolution first, high resolution second.
datas = [data_lr, data_hr]

print(psf_hr.shape, psf_lr.shape)


In [None]:
# Display both simulated images with the true source positions overlaid

# Asinh colour stretches for the low- and high-resolution images
lr_norm = AsinhMapping(minimum=-10, stretch=10, Q=10)
hr_norm = AsinhMapping(minimum=-1, stretch=10, Q=5)

# True source positions, taken from the columns of `shifts`
xtrue = shifts[:, 0]
ytrue = shifts[:, 1]

# Convert to sky coordinates through the high-resolution WCS; the pixel
# coordinates are passed in (y, x) order
ratrue, dectrue = wcs_hr.wcs_pix2world(ytrue, xtrue, 0)
catalog_true = np.array([ratrue, dectrue]).T

# Project the sky coordinates onto the low-resolution pixel grid
Ytrue, Xtrue = wcs_lr.wcs_world2pix(ratrue, dectrue, 0)

# RGB renderings of the two cubes
img_rgb = scarlet.display.img_to_rgb(lr, norm=lr_norm)
hr_img = scarlet.display.img_to_rgb(hr, norm=hr_norm)

# Low resolution on the left, high resolution on the right
plt.figure(figsize=(15, 30))
for panel, rgb, xs, ys in ((121, img_rgb, Xtrue, Ytrue),
                           (122, hr_img, xtrue, ytrue)):
    plt.subplot(panel)
    plt.imshow(rgb)
    plt.plot(xs, ys, 'xk', label='true positions')
    plt.legend()
plt.show()

In [None]:

# Single-resolution fit of the high-resolution image

# Model frame with the shape and channels of the high-resolution cube
model_frame = scarlet.Frame(hr.shape, psf=model_psf_hr, channels=channels_hr)

# Observation of the high-resolution cube, matched to the model frame
observation = scarlet.Observation(hr, psf=scarlet.ImagePSF(psf_hr), channels=channels_hr)
observation = observation.match(model_frame)

# One extended source per true position; centres are given as (y, x)
sources = [
    scarlet.ExtendedSource(model_frame, (y, x), observation)
    for y, x in zip(ytrue, xtrue)
]

# Fit: 200 is the first argument to fit(), with relative tolerance e_rel=1e-6
blend = scarlet.Blend(sources, observation)
blend.fit(200, e_rel=1e-6)

# Model, rendered model, observation and residual
scarlet.display.show_scene(
    sources,
    norm=hr_norm,
    observation=observation,
    show_rendered=True,
    show_observed=True,
    show_residual=True,
)
plt.show()

# Sum every source model in the frame of the first source ...
model_frame = sources[0].frame
model = np.zeros(model_frame.shape)
for src in sources:
    model += src.get_model(frame=model_frame)

# ... then render the summed model through the observation
model = observation.render(model)
extent = scarlet.display.get_extent(observation.bbox)


In [None]:

# Single-resolution fit of the low-resolution image

# Model frame with the shape and channels of the low-resolution cube
model_frame = scarlet.Frame(lr.shape, psf=model_psf_lr, channels=channels_lr)

# Observation of the low-resolution cube, matched to the model frame
observation = scarlet.Observation(lr, psf=scarlet.ImagePSF(psf_lr), channels=channels_lr)
observation = observation.match(model_frame)

# One extended source per true position projected on the low-resolution grid
sources = [
    scarlet.ExtendedSource(model_frame, (y, x), observation)
    for y, x in zip(Ytrue, Xtrue)
]
blend = scarlet.Blend(sources, observation)

# Fit, then plot log10 of the absolute loss values recorded by the blend
blend.fit(200, e_rel=1e-8)
plt.plot(np.log10(np.array(np.abs(blend.loss))))
plt.show()

# Model, rendered model, observation and residual
scarlet.display.show_scene(
    sources,
    norm=AsinhMapping(minimum=-10, stretch=10, Q=10),
    observation=observation,
    show_rendered=True,
    show_observed=True,
    show_residual=True,
)
plt.show()

# Render the first source only and compare it with the data
s = sources[0].get_model(frame=model_frame)
model = observation.render(s)

# Residual scaled by its maximum, and the location(s) of that maximum
res = lr - model
res /= np.max(res)
pos = np.where(res == np.max(res))


In [None]:
# Colour stretches, low resolution first
norms = [lr_norm, hr_norm]

# Reuse the pickled runners when they can be read; otherwise build and cache them.
runners = None
try:
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by this notebook.
    with open("./runners_60.pkl", "rb") as fr:
        runners = pickle.load(fr)
except FileNotFoundError:
    print("File not found.")
except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
    # An unreadable or stale cache is reported as such instead of being
    # mislabelled as a missing file; the runners are rebuilt below.
    print("Could not load ./runners_60.pkl:", repr(err))

if runners is None:
    # Built outside the except clause so that a failure here is not chained to
    # the load error.
    # The multi-resolution runner sees both data sets with the high-resolution
    # model PSF; the single-resolution runners see one data set each.
    run_multi = Runner(datas, model_psf_hr, ra_dec = catalog_true)
    run_hr = Runner([data_hr], model_psf_hr, ra_dec = catalog_true)
    run_lr = Runner([data_lr], model_psf_lr, ra_dec = catalog_true)
    runners = [run_lr, run_hr, run_multi]
    with open("./runners_60.pkl","wb") as fr:
        pickle.dump(runners, fr)

In [None]:

sim = Simulation(cat, runners, ngal = 10, cats = [True]*3, hr_dict=hr_dict, lr_dict=lr_dict, n_lr=nlr)

print(sim.runners[-1].frame.shape)

# Restore previously saved results when the file can be read. Only the load is
# guarded, so an error raised by sim.plot() is no longer reported as a missing
# file.
try:
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by this notebook.
    with open("./sim_results.pkl", "rb") as f:
        previous_results = pickle.load(f)
except FileNotFoundError:
    print("File not found")
except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
    print("Could not load ./sim_results.pkl:", repr(err))
else:
    sim.results = previous_results
    sim.plot()

sim.run(5, plot = True, norms = norms, init_param=True)
sim.plot()


In [None]:
# Save the current results; the context manager closes the file even if the
# dump raises.
with open("sim_results.pkl","wb") as f:
    pickle.dump(sim.results, f)

# Repeat the simulation batches, re-plotting and overwriting the saved results
# after each one.
for i in range(100):
    sim.run(5, init_param=True)
    sim.plot()
    with open("sim_results.pkl","wb") as f:
        pickle.dump(sim.results, f)


# First entry of `_diff_kernels` for the first observation of the last runner
# and of the first two runners.
diff =  sim.runners[-1].observations[0]._diff_kernels[0]
diff_lr = sim.runners[0].observations[0]._diff_kernels[0]
diff_hr = sim.runners[1].observations[0]._diff_kernels[0]
from mr_tools.pictures import Pictures

In [None]:
import galsim

# Compare a near-point source convolved with the PSF against the PSF itself.
# NOTE(review): `pic1` is not defined in the visible cells -- confirm which
# picture object is meant before running this cell.

# Drawing options shared by both images: a 51x51 stamp drawn in real space.
draw_kwargs = dict(nx=51,
                   ny=51,
                   method='real_space',
                   use_true_center=True,
                   scale=0.1)

# Unit-flux Gaussian with a negligible width
dirac = galsim.Gaussian(sigma=1.e-20).withFlux(1)

# The narrow Gaussian convolved with the first PSF object ...
star = galsim.Convolve(dirac, pic1.psfs_obj[0]).drawImage(**draw_kwargs).array

# ... and that PSF alone, normalised to unit flux
psf = pic1.psfs_obj[0].withFlux(1).drawImage(**draw_kwargs).array

# Show the convolved image, then its difference from the PSF
for panel in (star, (star-psf)):
    plt.imshow(panel)
    plt.colorbar()
    plt.show()