In [None]:
import scarlet
from scarlet.renderer import ResolutionRenderer
import galsim
from mr_tools import galsim_compare_tools as gct
from mr_tools import simulations as sim
from astropy import wcs as WCS
import scipy.stats as scs
import pickle
import numpy
import warnings
warnings.filterwarnings("ignore")

# Comparing galsim and scarlet resamplings 

In this notebook, we run the scarlet and galsim resampling schemes on simulated images at different resolutions. We generate images of galaxies at the resolutions of the HST, Euclid, Roman, HSC and Rubin surveys and compare resamplings from one resolution to another with both approaches. Comparisons are conducted on galaxies from the COSMOS sample catalog; the simulation loop below uses catalogue indices 0–999 (1000 galaxies) for every pair of surveys.

In [None]:
def load_surveys():
    """Create the description dictionaries of the HST, EUCLID, ROMAN, HSC and RUBIN surveys.

    Each dictionary contains the survey 'name', its 'pixel' scale and 'psf'
    width in arcseconds, its 'channels', and the 'sky', 'exp_time' and
    'zero_point' arrays consumed by the image simulations.

    NOTE(review): an earlier docstring described the PSF width as a FWHM while
    the inline comment below calls it a sigma — confirm which convention the
    simulation code expects.

    Returns
    -------
    tuple of dict
        (HST, EUCLID, ROMAN, HSC, RUBIN)
    """
    # Local import: the notebook only does `import numpy`, and the `np` alias
    # would otherwise have to come from `%pylab inline` in a later cell.
    import numpy as np

    # Pixel scales in arcseconds.
    pix_ROMAN = 0.11
    pix_RUBIN = 0.2
    pix_HST = 0.06
    pix_EUCLID = 0.101
    pix_HSC = 0.167

    # Sigma of the psf profile in arcseconds.
    sigma_ROMAN = 0.11*np.array([1.86]) #https://arxiv.org/pdf/1702.01747.pdf Z-band
    sigma_RUBIN = np.array([0.297]) #https://www.lsst.org/about/camera/features
    sigma_EUCLID = np.array([0.16]) #https://sci.esa.int/documents/33859/36320/1567253682555-Euclid_presentation_Paris_1Dec2009.pdf
    sigma_HST = np.array([0.074]) #Source https://hst-docs.stsci.edu/display/WFC3IHB/6.6+UVIS+Optical+Performance#id-6.6UVISOpticalPerformance-6.6.1 800nm
    sigma_HSC = np.array([0.285]) #https://hsc-release.mtk.nao.ac.jp/doc/ deep+udeep

    EUCLID = {'name': 'EUCLID',
              'pixel': pix_EUCLID,
              'psf': sigma_EUCLID,
              'channels': ['VIS'],
              'sky': np.array([22.9]),
              'exp_time': np.array([2260]),
              'zero_point': np.array([6.85])}
    HST = {'name': 'HST',
           'pixel': pix_HST,
           'psf': sigma_HST,
           'channels': ['f814w'],
           'sky': np.array([22]),
           'exp_time': np.array([3000]),
           'zero_point': np.array([20])}
    HSC = {'name': 'HSC',
           'pixel': pix_HSC,
           'psf': sigma_HSC,
           'channels': ['r'],
           'sky': np.array([20.6]),
           'exp_time': np.array([600]),
           'zero_point': np.array([87.74])}
    ROMAN = {'name': 'ROMAN',
             'pixel': pix_ROMAN,
             'psf': sigma_ROMAN,
             'channels': ['Y106'],
             'sky': np.array([22]),  # TODO: value not checked
             'exp_time': np.array([3000]),  # TODO: value not checked
             'zero_point': np.array([26.41])}
    RUBIN = {'name': 'RUBIN',
             'pixel': pix_RUBIN,
             'psf': sigma_RUBIN,
             'channels': ['r'],
             'sky': np.array([21.2]),
             'exp_time': np.array([5520]),
             'zero_point': np.array([43.70])}

    return HST, EUCLID, ROMAN, HSC, RUBIN

In [None]:
%pylab inline
#Loading stuffs
data_dir='/Users/remy/Desktop/LSST_Project/GalSim/examples/data/COSMOS_23.5_training_sample/'

HST, EUCLID, ROMAN, HSC, RUBIN = load_surveys()

cat = galsim.COSMOSCatalog(dir=data_dir, file_name = 'real_galaxy_catalog_23.5.fits')
print(len(cat))
def mad(x):
    """Median absolute deviation of `x`, scaled to match a Gaussian sigma.

    Reproduces the default behaviour of the legacy
    ``scipy.stats.median_absolute_deviation`` (scale ~1.4826). That function
    was deprecated in SciPy 1.5 and removed in 1.9; its replacement is
    ``median_abs_deviation(..., scale='normal')``, used here when available.
    """
    import scipy.stats as _stats
    if hasattr(_stats, 'median_abs_deviation'):
        return _stats.median_abs_deviation(x, scale='normal')
    # SciPy < 1.5 only ships the legacy function (default scale=1.4826).
    return _stats.median_absolute_deviation(x)

# Side length (pixels) of the galsim difference-kernel image; also part of
# the results file name.
npsf = 41

# Default colormap and interpolation for every imshow in the notebook.
matplotlib.rc('image', interpolation='none', cmap='gist_stern')

In [None]:
# Shape of the low resolution images (pixels per side)
nlr = 90
# Number of catalogue galaxies (indices 0..n_gal-1) simulated per survey pair
n_gal = 1000
# Galaxy index at which diagnostic residual plots are displayed
plot_index = 500
# List of surveys; pixel scales increase along the list, so pairing each
# survey with itself and those after it keeps the HR pixel <= the LR pixel.
surveys = [HST, EUCLID, ROMAN, HSC, RUBIN]
# Scarlet-specific channels
channel_hr = ['hr']
channel_lr = ['lr']
channels = channel_lr + channel_hr
# Storage for the results: one entry per (survey pair, galaxy)
reconstructions = {'survey_lr': [], 'survey_hr': [], 'n_hr': [], 's_sdr': [], 'g_sdr': []}


def _hr_stamp_size(n_lr, pix_lr, pix_hr):
    """Side length (pixels) of the HR stamp matching an `n_lr`-pixel LR stamp.

    NOTE(review): a fractional part >= 0.5 is rounded *down* and one < 0.5 is
    rounded *up*, i.e. away from the nearest integer unless the ratio is
    exact. Kept unchanged so results stay comparable — confirm it is intended.
    """
    n_hr = np.around(n_lr * pix_lr / pix_hr, decimals=3)
    # Builtin int(): the np.int alias was removed in NumPy 1.24.
    if n_hr - int(n_hr) >= 0.5:
        return int(np.floor(n_hr))
    return int(np.ceil(n_hr))


for gg, surveyhr in enumerate(surveys):
    for surveylr in surveys[gg:]:
        print('HR', surveyhr['name'])
        print('LR', surveylr['name'])
        # The stamp size depends only on the survey pair, not on the galaxy.
        nhr = _hr_stamp_size(nlr, surveylr['pixel'], surveyhr['pixel'])
        for i in range(n_gal):
            # Galsim setup: noiseless, unshifted HR/LR images of catalogue galaxy i
            pic_hr, pic_lr = gct.mk_scene(surveyhr,
                                          surveylr,
                                          cat,
                                          (nhr, nhr),
                                          (nlr, nlr),
                                          1,
                                          gal_type = 'real',
                                          random_seds = False,
                                          noise = False,
                                          use_cat = False,
                                          shift = False,
                                          index = i)

            wcs_hr = pic_hr.wcs
            wcs_lr = pic_lr.wcs

            # Images normalised to unit sum
            data_hr = pic_hr.cube[0] / np.sum(pic_hr.cube[0])
            data_lr = pic_lr.cube[0] / np.sum(pic_lr.cube[0])

            psf_hr = scarlet.ImagePSF(np.array(pic_hr.psfs)[None, :, :])
            psf_lr = scarlet.ImagePSF(np.array(pic_lr.psfs)[None, :, :])

            angle = 0

            ## GSO from psf_hr for galsim
            psf_hr_galsim = galsim.InterpolatedImage(galsim.Image(psf_hr.get_model()[0]),
                                                     scale = surveyhr['pixel'], use_true_center = False)
            ## deconvolution kernel for diff kernel
            deconv = galsim.Deconvolve(psf_hr_galsim)
            ## Interpolation of low resolution psf at high resolution
            psf_lr_hr = galsim.InterpolatedImage(galsim.Image(psf_lr.get_model()[0]),
                                                 scale = surveylr['pixel'], use_true_center = False)
            ## Difference kernel from galsim. It is passed as a GSObject to
            ## gct.interp_galsim, so it is not rasterised here.
            diff_gal = galsim.Convolve(deconv, psf_lr_hr)

            # scarlet setup: built from the first galaxy of the pair and reused.
            # NOTE(review): assumes the frame/observations do not depend on the
            # galaxy (same WCS and PSFs for every galaxy, shift=False) — confirm.
            if i == 0:
                obs, frame = gct.setup_scarlet(data_hr,
                                               data_lr,
                                               wcs_hr,
                                               wcs_lr,
                                               psf_hr,
                                               psf_lr,
                                               channels,
                                               'intersection')
                obs_lr, obs_hr = obs
                renderer = ResolutionRenderer(obs_lr, frame)

            # Scarlet resampling of the HR image, compared to data_lr below
            scar_rec = renderer(data_hr[None, :, :])
            # Galsim resampling of the HR image, compared to data_lr below
            gal_rec = gct.interp_galsim(data_hr,
                                        data_lr,
                                        diff_gal,
                                        angle,
                                        surveyhr['pixel'],
                                        surveylr['pixel'])

            if i == plot_index:
                figure(figsize = (25, 5))
                subplot(141)
                plt.title('Scarlet residuals')
                imshow(scar_rec[0] - data_lr)
                colorbar()
                subplot(142)
                plt.title('Galsim residuals')
                imshow(gal_rec.array - data_lr)
                colorbar()
                subplot(143)
                plt.title('lr image')
                imshow(data_lr)
                colorbar()
                subplot(144)
                plt.title('hr image')
                imshow(data_hr)
                colorbar()
                show()

            reconstructions['survey_lr'].append(surveylr)
            reconstructions['survey_hr'].append(surveyhr)
            reconstructions['n_hr'].append(nhr)
            reconstructions['s_sdr'].append(sim.SDR(data_lr, scar_rec))
            reconstructions['g_sdr'].append(sim.SDR(data_lr, gal_rec.array))

# Persist the per-galaxy results; the file name records npsf.
with open('Precision_npsf=' + str(npsf) + '.pkl', 'wb') as bfile:
    pickle.dump(reconstructions, bfile)


In [None]:
# Reload the results written by the simulation cell. Only unpickle files you
# produced yourself: pickle.load can execute arbitrary code.
with open('Precision_npsf=' + str(npsf) + '.pkl', 'rb') as bfile:
    reconstructions = pickle.load(bfile)

# One entry per (survey pair, galaxy), in simulation-loop order.
survey_lr = np.array(reconstructions['survey_lr'])
survey_hr = np.array(reconstructions['survey_hr'])
n_hrs = np.array(reconstructions['n_hr'])
s_sdr = np.array(reconstructions['s_sdr'])
g_sdr = np.array(reconstructions['g_sdr'])


In [None]:
# Larger tick labels on both axes for the summary figures.
matplotlib.rc(('xtick', 'ytick'), labelsize=20)

def plot_sdr(x, condition, label = [None, None]):
    """Plot the median SDR with MAD error bars for scarlet and galsim.

    x : horizontal position; the galsim point is shifted by 0.05 so both
        error bars remain visible.
    condition : boolean mask selecting entries of the global `s_sdr`/`g_sdr`.
    label : legend labels for the scarlet (red circles) and galsim (blue
        crosses) points.
    """
    scarlet_sel = s_sdr[condition]
    galsim_sel = g_sdr[condition]
    errorbar(x, np.median(scarlet_sel), yerr=mad(scarlet_sel),
             fmt='or', label=label[0])
    errorbar(x + 0.05, np.median(galsim_sel), yerr=mad(galsim_sel),
             fmt='xb', label=label[1])

def make_matrices(s, g, hr_labels=None, lr_labels=None, survey_list=None):
    """Summarise per-galaxy SDRs into (HR survey x LR survey) matrices.

    Parameters
    ----------
    s, g : 1-D arrays
        Per-galaxy SDRs of scarlet and galsim, aligned with the labels.
    hr_labels, lr_labels : 1-D arrays, optional
        HR / LR survey of each entry. Default to the notebook globals
        `survey_hr` and `survey_lr`.
    survey_list : sequence, optional
        Surveys indexing the matrix rows (HR) and columns (LR). Defaults to
        the notebook global `surveys`.

    Returns
    -------
    s_SDRmatrix, g_SDRmatrix : mean SDR of scarlet / galsim for each pair
    s_stdmatrix, g_stdmatrix : standard deviation of those SDRs
    m_matrix : minimum over the pair's galaxies of (scarlet SDR - galsim SDR)
    per_cent : percentage of the pair's galaxies with scarlet SDR < galsim SDR

    Only entries with column >= row are filled; the others, and pairs that
    have no galaxy at all, are left at 0.
    """
    hr_labels = survey_hr if hr_labels is None else hr_labels
    lr_labels = survey_lr if lr_labels is None else lr_labels
    survey_list = surveys if survey_list is None else survey_list

    n = len(survey_list)
    s_SDRmatrix = np.zeros((n, n))
    g_SDRmatrix = np.zeros((n, n))
    s_stdmatrix = np.zeros((n, n))
    g_stdmatrix = np.zeros((n, n))
    m_matrix = np.zeros((n, n))
    per_cent = np.zeros((n, n))
    for row, hr in enumerate(survey_list):
        for col in range(row, n):
            lr = survey_list[col]
            # Selection mask of this pair, built once instead of per statistic.
            mask = (hr_labels == hr) & (lr_labels == lr)
            if not np.any(mask):
                # np.min would raise on an empty selection; leave the zeros.
                continue
            s_pair = s[mask]
            g_pair = g[mask]
            delta = s_pair - g_pair
            s_SDRmatrix[row, col] = np.mean(s_pair)
            g_SDRmatrix[row, col] = np.mean(g_pair)
            s_stdmatrix[row, col] = np.std(s_pair)
            g_stdmatrix[row, col] = np.std(g_pair)
            m_matrix[row, col] = np.min(delta)
            per_cent[row, col] = np.mean(delta < 0) * 100
    return s_SDRmatrix, g_SDRmatrix, s_stdmatrix, g_stdmatrix, m_matrix, per_cent

# Summary matrices (HR survey rows x LR survey columns) of the SDR results.
(s_SDRmatrix, g_SDRmatrix,
 s_stdmatrix, g_stdmatrix,
 m_matrix, pc) = make_matrices(s_sdr, g_sdr)
    

In [None]:
names = [s['name'] for s in surveys]
ticks = np.arange(len(surveys))


def _matrix_panel(matrix, title, **imshow_kwargs):
    """Draw one (HR survey x LR survey) matrix on the current axes with a colorbar."""
    plt.title(title, fontsize = 25)
    plt.imshow(matrix, cmap = 'Blues', origin = 'lower', **imshow_kwargs)
    plt.xticks(ticks, names)
    plt.yticks(ticks, names)
    plt.ylabel('HR survey', fontsize = 20)
    plt.xlabel('LR survey', fontsize = 20)
    plt.colorbar()


plt.figure(figsize = (20, 8))
plt.subplot(121)
# Scarlet's lower colour limit is taken from the galsim matrix.
_matrix_panel(s_SDRmatrix, 'Scarlet SDRs', vmin = np.min(g_SDRmatrix), vmax = 60)
plt.subplot(122)
_matrix_panel(g_SDRmatrix, 'Galsim SDRs', vmax = 60)
plt.savefig('SDR_matrix.png')
plt.show()

plt.figure(figsize = (30, 8))
# Raw strings: '\s' is an invalid escape sequence in a plain string literal.
plt.subplot(131)
_matrix_panel(s_stdmatrix, r'$\sigma_{SDR}$ Scarlet', vmin = 0, vmax = 10)
plt.subplot(132)
_matrix_panel(g_stdmatrix, r'$\sigma_{SDR}$ Galsim', vmin = 0, vmax = 10)
plt.subplot(133)
# vmin=0 clips negative differences to the lowest colour; the printed
# minimum shows whether any pair has galsim ahead on average.
_matrix_panel(s_SDRmatrix - g_SDRmatrix, 'Scarlet - Galsim SDRs', vmin = 0, vmax = 20)
print(np.min(s_SDRmatrix - g_SDRmatrix))
plt.savefig('SDR_std.png')
plt.show()

In [None]:

# Most negative (scarlet SDR - galsim SDR) over the galaxies of each pair.
plt.title('Min Scarlet - Galsim SDRs', fontsize = 25)
plt.imshow(m_matrix, vmin = -20, vmax = 20, cmap = 'seismic', origin = 'lower')
plt.xticks(np.arange(len(surveys)), names)
plt.yticks(np.arange(len(surveys)), names)
plt.ylabel('HR survey', fontsize = 20)
plt.xlabel('LR survey', fontsize = 20)
plt.colorbar()
plt.show()

# Percentage of the galaxies of each pair for which scarlet's SDR is below galsim's.
plt.title('Galaxies with Scarlet SDR < Galsim SDR (%)', fontsize = 25)
plt.imshow(pc, cmap = 'Blues', origin = 'lower')
plt.xticks(np.arange(len(surveys)), names)
plt.yticks(np.arange(len(surveys)), names)
plt.ylabel('HR survey', fontsize = 20)
plt.xlabel('LR survey', fontsize = 20)
plt.colorbar()
plt.show()

In [None]:
# Survey pair to inspect. Results are stored in blocks of 1000 galaxies per
# pair in simulation-loop order; with the five surveys, k = 13 is HR=HSC, LR=RUBIN.
k = 13
k_slice = slice(k * 1000, (k + 1) * 1000)
s_k = s_sdr[k_slice]
g_k = g_sdr[k_slice]

plt.plot(s_k, 'or')
plt.plot(g_k, 'ob')
plt.show()

plt.plot(s_k - g_k, 'ob')
plt.show()

# Positions (within the pair) where scarlet's SDR is below galsim's.
print(np.where(s_k - g_k < 0))

In [None]:
# Preview one galaxy of the HSC -> RUBIN pair. The HR stamp size is computed
# here instead of reusing whatever `nhr` an earlier cell left behind (after
# the simulation loop that is the RUBIN -> RUBIN size). Same rounding rule
# as the simulation loop.
nhr = np.around(nlr * RUBIN['pixel'] / HSC['pixel'], decimals=3)
nhr = int(np.floor(nhr)) if nhr - int(nhr) >= 0.5 else int(np.ceil(nhr))

# NOTE(review): index 9555 lies outside range(1000) used by the simulation
# loop, so this galaxy is not part of the saved statistics — confirm intent.
pic_hr, pic_lr = gct.mk_scene(HSC,
                              RUBIN,
                              cat,
                              (nhr, nhr),
                              (nlr, nlr),
                              1,
                              gal_type = 'real',
                              random_seds = False,
                              noise = False,
                              use_cat = False,
                              shift = False,
                              index = 9555)

plt.title('HSC (HR) image')
plt.imshow(pic_hr.cube[0])
plt.show()
plt.title('RUBIN (LR) image')
plt.imshow(pic_lr.cube[0])
plt.show()

In [None]:
# Re-run the scarlet / galsim comparison for a single galaxy of the
# HSC -> RUBIN pair.
surveyhr = HSC
surveylr = RUBIN

# HR stamp size. NOTE(review): a fractional part >= 0.5 is rounded down and
# one < 0.5 is rounded up (away from the nearest integer) — confirm intended.
nhr = np.around(nlr*surveylr['pixel']/surveyhr['pixel'], decimals=3)
# Builtin int(): the np.int alias was removed in NumPy 1.24.
if nhr - int(nhr) >= 0.5:
    nhr = int(np.floor(nhr))
else:
    nhr = int(np.ceil(nhr))
print(nhr, nlr)

# Galsim setup:
pic_hr, pic_lr = gct.mk_scene(surveyhr,
                              surveylr,
                              cat,
                              (nhr,nhr),
                              (nlr,nlr),
                              1,
                              gal_type = 'real',
                              random_seds = False,
                              noise = False,
                              use_cat = False,
                              shift = False,
                              index = 9555
                             )

wcs_hr = pic_hr.wcs
wcs_lr = pic_lr.wcs

# Images normalised to unit sum
data_hr = pic_hr.cube[0]/np.sum(pic_hr.cube[0])
data_lr = pic_lr.cube[0]/np.sum(pic_lr.cube[0])


psf_hr = scarlet.ImagePSF(np.array(pic_hr.psfs)[None, :,:])
psf_lr = scarlet.ImagePSF(np.array(pic_lr.psfs)[None, :,:])

angle = 0

## GSO from psf_hr for galsim
psf_hr_galsim = galsim.InterpolatedImage(galsim.Image(psf_hr.get_model()[0]),
                                   scale = surveyhr['pixel'], use_true_center = False)
## deconvolution kernel for diff kernel
deconv = galsim.Deconvolve(psf_hr_galsim)
## Interpolation of low resolution psf at high resolution
psf_lr_hr = galsim.InterpolatedImage(galsim.Image(psf_lr.get_model()[0]),
                                   scale = surveylr['pixel'], use_true_center = False)
## Difference kernel from galsim
diff_gal = galsim.Convolve(deconv, psf_lr_hr)
# Rasterised difference kernel; not used below, kept for interactive inspection.
diff = diff_gal.drawImage(nx=npsf,
                          ny=npsf,
                          scale=surveyhr['pixel']).array[None, :,:]

# scarlet setup

obs, frame = gct.setup_scarlet(data_hr,
                                   data_lr,
                                   wcs_hr,
                                   wcs_lr,
                                   psf_hr,
                                   psf_lr,
                                   channels,
                                   'intersection')
obs_lr, obs_hr = obs
renderer = ResolutionRenderer(obs_lr, frame)

# Scarlet resampling of the HR image, compared to data_lr below
scar_rec = renderer(data_hr[None, :,:])
# Galsim resampling of the HR image, compared to data_lr below
gal_rec = gct.interp_galsim(data_hr,
                            data_lr,
                            diff_gal,
                            angle,
                            surveyhr['pixel'],
                            surveylr['pixel'])

figure(figsize = (25,5))
subplot(141)
plt.title('Scarlet residuals')
imshow((scar_rec[0]-data_lr))
colorbar()
subplot(142)
plt.title('Galsim residuals')
imshow((gal_rec.array-data_lr))
colorbar()
# log10 of non-positive pixels yields -inf/nan; the warnings are silenced by
# the filterwarnings call in the import cell.
subplot(143)
plt.title('lr image')
imshow(np.log10(data_lr))
colorbar()
subplot(144)
plt.title('hr image')
imshow(np.log10(data_hr))
colorbar()
show()

print(sim.SDR(data_lr,scar_rec))
print(sim.SDR(data_lr,gal_rec.array))