In [None]:
import scarlet
from scarlet.renderer import ResolutionRenderer
import galsim
from astropy import wcs as WCS
from mr_tools import galsim_compare_tools as gct
import time
import pickle

# Show where GalSim's bundled data lives; the COSMOS example catalog loaded below is read from this directory.
print(galsim.meta_data.share_dir)

In [None]:
def load_surveys():
    """Create the survey description dictionaries used by the timing runs.

    Covers HST, EUCLID, ROMAN (formerly WFIRST), HSC and RUBIN (LSST).

    Returns
    -------
    HST, EUCLID, ROMAN, HSC, RUBIN : dict
        One dictionary per survey, in that order, each with the keys
        'name', 'pixel' (pixel scale in arcsec), 'psf' (PSF size in arcsec),
        'channels', 'sky', 'exp_time' and 'zero_point'. Every survey here has a
        single channel, so the array-valued entries have length 1.

    Notes
    -----
    NOTE(review): the PSF values are labelled as Gaussian sigmas below, but some of
    the cited references quote FWHMs -- confirm per survey before relying on them.
    The ROMAN sky level and exposure time have not been checked.
    """
    # Local import: the notebook's import cell does not import numpy, so `np` would
    # otherwise only exist if `%pylab` had already been run in the kernel.
    import numpy as np

    # Pixel scales in arcseconds.
    pix_ROMAN = 0.11
    pix_RUBIN = 0.2
    pix_HST = 0.06
    pix_EUCLID = 0.101
    pix_HSC = 0.167

    #Sigma of the psf profile in arcseconds.
    sigma_ROMAN = 0.11*np.array([1.86]) #https://arxiv.org/pdf/1702.01747.pdf Z-band
    sigma_RUBIN = np.array([0.297]) #https://www.lsst.org/about/camera/features
    sigma_EUCLID = np.array([0.16]) #https://sci.esa.int/documents/33859/36320/1567253682555-Euclid_presentation_Paris_1Dec2009.pdf
    sigma_HST = np.array([0.074]) #Source https://hst-docs.stsci.edu/display/WFC3IHB/6.6+UVIS+Optical+Performance#id-6.6UVISOpticalPerformance-6.6.1 800nm
    sigma_HSC = np.array([0.285]) #https://hsc-release.mtk.nao.ac.jp/doc/ deep+udeep


    EUCLID = {'name': 'EUCLID',
              'pixel': pix_EUCLID ,
              'psf': sigma_EUCLID,
              'channels': ['VIS'],
              'sky':np.array([22.9]),
              'exp_time': np.array([2260]),
              'zero_point': np.array([6.85])}
    HST = {'name': 'HST',
           'pixel': pix_HST,
           'psf': sigma_HST,
           'channels': ['f814w'],
           'sky':np.array([22]),
           'exp_time': np.array([3000]),
           'zero_point': np.array([20])}
    HSC = {'name': 'HSC',
           'pixel': pix_HSC,
           'psf': sigma_HSC,
           'channels': ['r'],
           'sky': np.array([20.6]),
           'exp_time': np.array([600]),
           'zero_point': np.array([87.74])}
    ROMAN = {'name': 'ROMAN',
             'pixel': pix_ROMAN,
             'psf': sigma_ROMAN,
             'channels': ['Y106'],
             'sky':np.array([22]), # NOTE: value not checked
             'exp_time': np.array([3000]), # NOTE: value not checked
             'zero_point': np.array([26.41])}
    RUBIN = {'name': 'RUBIN',
             'pixel': pix_RUBIN,
             'psf': sigma_RUBIN,
             'channels': ['r'],
             'sky': np.array([21.2]),
             'exp_time': np.array([5520]),
             'zero_point': np.array([43.70])}

    return HST, EUCLID, ROMAN, HSC, RUBIN

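This notebook times the execution of the GalSim and scarlet resampling routines for every pair of higher/lower-resolution surveys and for a range of high-resolution image sizes N.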

In [None]:
%pylab inline
data_dir=galsim.meta_data.share_dir
#Reference catalog for simulations
cat = galsim.COSMOSCatalog(dir=data_dir, file_name = 'real_galaxy_catalog_23.5_example.fits')
#Survey dictionaries
HST, EUCLID, ROMAN, HSC, RUBIN = load_surveys()
#CHannels
channel_hr = ['hr']
channel_lr = ['lr']
channels = channel_lr+channel_hr

# Shapes tested for timing
shapes = [200, 150, 100, 70, 50, 30, 20]
#List of surveys
surveys = [HST, EUCLID, ROMAN, HSC, RUBIN]
#Size of the psf (pixels on-a-side)
npsf = 41

matplotlib.rc('image', cmap='gist_stern')
matplotlib.rc('image', interpolation='none')
matplotlib.rc('xtick', labelsize=20) 

In [None]:
# Selects which resampling path the timing loop below measures:
#   0        -> scarlet's ResolutionRenderer and gct.interp_galsim; results go to Timings_npsf=<npsf>.pkl
#   non-zero -> gct.interp_galsim_sinc (GalSim with sinc interpolation); results go to Timings_sinc_npsf=<npsf>.pkl
# The plotting cells read both pickles, so the loop must be run once with each setting.
sinc = 0

In [None]:

# Dictionary for storing results
structure = {'survey_hr':[], 'survey_lr': [], 'n_hr': [], 'n_lr': [], 's_mean': [], 's_std': [], 
             'g_mean': [], 'g_std': [], 'gs_mean': [], 'gs_std': []}


#Timing for all combinations of high-low resolution images and all shapes in `shapes`
for nhr in shapes:
    for k, surveyhr in enumerate(surveys):
        print('survey hr  ' + str(nhr), surveyhr)
        for g, surveylr in enumerate(surveys[k+1:]):
            print(surveylr)
            # Making sure that the lr image spans roughly the same area on the sky as the hr image.
            nlr = int(nhr*surveyhr['pixel']/surveylr['pixel'])
            
            # Make simulations, picks a galaxy at random
            r = np.random.rand(1)*90
            pic_hr, pic_lr = gct.mk_scene(surveyhr, 
                                          surveylr, 
                                          cat,
                                          (nhr, nhr), 
                                          (nlr, nlr), 
                                          1,
                                          'real',
                                          noise = False,
                                          random_seds=False,
                                          index = 1,
                                          use_cat = False)
            
            wcs_hr = pic_hr.wcs
            wcs_lr = pic_lr.wcs
        
            data_hr = pic_hr.cube[0]/np.sum(pic_hr.cube[0])
            data_lr = pic_lr.cube[0]/np.sum(pic_lr.cube[0])
            
        
            psf_hr = scarlet.ImagePSF(np.array(pic_hr.psfs)[0][None, :,:])
            psf_lr = scarlet.ImagePSF(np.array(pic_lr.psfs)[0][None, :,:])
        
            ref = wcs_hr.wcs.crval.reshape((1,2))
            angle = 0
            
            if g == 0:
                # scarlet setup
                obs, frame = gct.setup_scarlet(data_hr, data_lr, wcs_hr, wcs_lr, psf_hr, psf_lr, channels, 'intersection')
                obs_lr, obs_hr = obs
                renderer = ResolutionRenderer(obs_lr, frame)
             
            # Galsim setup:
            ## GSO from psf_hr for galsim
            psf_hr_galsim = galsim.InterpolatedImage(galsim.Image(psf_hr.get_model()[0]), 
                                               scale = surveyhr['pixel'], use_true_center = False)
            ## deconvolution kernel for diff kernel
            deconv = galsim.Deconvolve(psf_hr_galsim)
            ## Interpolation of low resolution psf at high resolution
            psf_lr_hr = galsim.InterpolatedImage(galsim.Image(psf_lr.get_model()[0]), 
                                               scale = surveylr['pixel'], use_true_center = False)
            ## Difference kernel from galsim
            diff_gal = galsim.Convolve(deconv, psf_lr_hr)
            
            if sinc == 0:
                # Scarlet timing
                t_s = %timeit -n 1000 -r 2 -o renderer(data_hr[None, :,:])
                s_mean = np.array(t_s.all_runs).mean()/t_s.loops
                s_std = np.array(t_s.all_runs).std()/t_s.loops 
                # Galsim timing
                t_g = %timeit -n 1000 -r 2 -o gct.interp_galsim(data_hr, data_lr, diff_gal, angle, surveyhr['pixel'], surveylr['pixel'])
                g_mean = np.array(t_g.all_runs).mean()/t_g.loops 
                g_std = np.array(t_g.all_runs).std()/t_g.loops
                
                #Storage
                structure['survey_hr'].append(surveyhr) 
                structure['survey_lr'].append(surveylr) 
                structure['n_hr'].append(nhr) 
                structure['n_lr'].append(nlr)
                structure['s_mean'].append(s_mean)
                structure['s_std'].append(s_std)
                structure['g_mean'].append(g_mean) 
                structure['g_std'].append(g_std)
            else:
                t_gs = %timeit -n 1000 -r 2 -o gct.interp_galsim_sinc(data_hr, data_lr, diff_gal, angle, surveyhr['pixel'], surveylr['pixel'])
                gs_mean = np.array(t_gs.all_runs).mean()/t_gs.loops 
                gs_std = np.array(t_gs.all_runs).std()/t_gs.loops
            
                #Storage
                structure['survey_hr'].append(surveyhr) 
                structure['survey_lr'].append(surveylr) 
                structure['n_hr'].append(nhr) 
                structure['n_lr'].append(nlr)
                structure['gs_mean'].append(gs_mean) 
                structure['gs_std'].append(gs_std)
            
            
    #Saving results
    if sinc == 0:
        afile = open('Timings_npsf='+str(npsf)+'.pkl', 'wb')
        pickle.dump(structure, afile)
        afile.close()
    else:
        afile = open('Timings_sinc_npsf='+str(npsf)+'.pkl', 'wb')
        pickle.dump(structure, afile)
        afile.close()

In [None]:
# Load the timings written by the timing loop (run once with sinc == 0, once with sinc != 0).
with open('Timings_npsf='+str(npsf)+'.pkl', 'rb') as afile:
    structure = pickle.load(afile)
with open('Timings_sinc_npsf='+str(npsf)+'.pkl', 'rb') as afile:
    structure_sinc = pickle.load(afile)

#Reading results
n_hrs = np.array(structure['n_hr'])
n_lrs = np.array(structure['n_lr'])
s_mean = np.array(structure['s_mean'])
s_std = np.array(structure['s_std'])
g_mean = np.array(structure['g_mean'])
g_std = np.array(structure['g_std'])
gs_mean = np.array(structure_sinc['gs_mean'])
gs_std = np.array(structure_sinc['gs_std'])

# gs_* come from a separate run but are plotted against this run's n_hrs / n_lrs; that is
# only meaningful if both runs covered the same (N, M) combinations in the same order.
assert np.array_equal(np.array(structure_sinc['n_hr']), n_hrs), 'sinc run used different N values'
assert np.array_equal(np.array(structure_sinc['n_lr']), n_lrs), 'sinc run used different M values'

matplotlib.rc('xtick', labelsize=20)
matplotlib.rc('ytick', labelsize=20)
plt.close()

# (mean times, marker, colormap, method name) for each resampling method.
timing_series = ((s_mean, 'o', 'winter', 'scarlet'),
                 (g_mean, 's', 'copper', 'galsim'),
                 (gs_mean, 'x', 'summer', 'galsim sinc'))


def _scatter_mean_timings(x, colour, xlabel, colour_label, series):
    """Scatter the mean per-call time of each method in `series` against `x`, coloured by `colour`.

    Each method gets its own marker, colormap and labelled colorbar; a legend maps markers
    to methods.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for times, marker, cmap, method in series:
        points = ax.scatter(x, times, c=colour, marker=marker, cmap=cmap, label=method)
        cbar = fig.colorbar(points, ax=ax)
        cbar.set_label(colour_label + ' - ' + method, rotation=270, labelpad=20, fontsize=16)
    ax.set_yscale('log')
    ax.set_ylim([0.001, 1])
    ax.legend()
    ax.set_title('Mean timing of resampling runs', fontsize=20)
    ax.set_xlabel(xlabel, fontsize=20)
    ax.set_ylabel('t (s)', fontsize=20)
    plt.show()


#Plotting results. Mean timing of resampling runs as a function of N
_scatter_mean_timings(n_hrs, n_lrs, 'N (# of pixels)', 'M (# of pixels)', timing_series)

#Plotting results. Mean timing of resampling runs as a function of M
_scatter_mean_timings(n_lrs, n_hrs, 'M (# of pixels)', 'N (# of pixels)', timing_series)

#Plotting results. Mean scarlet timing as a function of M, one series per N
fig, ax = plt.subplots(figsize=(10, 6))
for n in np.unique(n_hrs):
    same_n = n_hrs == n
    ax.errorbar(n_lrs[same_n], s_mean[same_n], yerr=s_std[same_n], fmt='o', label='N = '+str(n))
ax.set_yscale('log')
ax.legend()
ax.set_title('Mean scarlet timing of resampling runs', fontsize=20)
ax.set_xlabel('M (# of pixels)', fontsize=20)
ax.set_ylabel('t (s)', fontsize=20)
plt.show()

In [None]:
#Timing ratios (scarlet / galsim) as a function of M
fig2, ax2 = plt.subplots(figsize=(10, 6))
sc1 = ax2.scatter(n_lrs, s_mean/g_mean, c=n_hrs, marker='o', cmap='winter')
ax2.set_yscale('log')
cbar = fig2.colorbar(sc1, ax=ax2)
cbar.set_label('N (# of pixels)', rotation=270, fontsize=20, labelpad=20)
# Reference line at ratio 1, where both codes take the same time.
ax2.plot([0, 220], [1, 1], '--k')
# Adding plotting parameters
ax2.set_title('Timing ratios of resampling runs', fontsize=20)
ax2.set_xlabel('M (# of pixels)', fontsize=20)
ax2.set_ylabel('time ratio (scarlet / galsim)', fontsize=20)
fig2.savefig('galsim_scarlet_ratio.png')
plt.show()

#Timing ratios (scarlet / galsim) as a function of N
fig2, ax2 = plt.subplots(figsize=(10, 6))
sc1 = ax2.scatter(n_hrs, s_mean/g_mean, c=n_lrs, marker='o', cmap='winter')
ax2.set_yscale('log')
cbar = fig2.colorbar(sc1, ax=ax2)
cbar.set_label('M (# of pixels)', rotation=270, fontsize=20, labelpad=20)
ax2.plot([0, 220], [1, 1], '--k')
# Adding plotting parameters
ax2.set_title('Timing ratios of resampling runs', fontsize=14)
ax2.set_xlabel('N (# of pixels)', fontsize=20)
ax2.set_ylabel('time ratio (scarlet / galsim)', fontsize=20)
plt.show()

#Colour representation of the timing ratios in the (N, M) plane
fig2, ax2 = plt.subplots(figsize=(10, 8))
sc1 = ax2.scatter(n_hrs, n_lrs, c=np.log10(s_mean/g_mean), vmax=2, vmin=-2, marker='o', cmap='seismic')
ax2.set_ylim([0, 220])
cbar = fig2.colorbar(sc1, ax=ax2)
cbar.set_label('log10(scarlet / galsim)', rotation=270, fontsize=20, labelpad=20)
# Adding plotting parameters
ax2.set_title('Timing ratios of resampling runs', fontsize=14)
ax2.set_xlabel('N (# of pixels)', fontsize=20)
ax2.set_ylabel('M (# of pixels)', fontsize=20)
fig2.savefig('measured_ratios.png')
plt.show()

In [None]:
#Comparison with galsim sinc: timing ratio scarlet / galsim-sinc.
# The ratio is s_mean/gs_mean, matching the output file names: galsim_sinc_scarlet_ratio.png
# and measured_sinc_ratios.png mirror galsim_scarlet_ratio.png and measured_ratios.png,
# which plot scarlet / galsim.
ratio_sinc = s_mean/gs_mean

#Timing ratios as a function of M
fig2, ax2 = plt.subplots(figsize=(10, 6))
sc1 = ax2.scatter(n_lrs, ratio_sinc, c=n_hrs, marker='o', cmap='winter')
ax2.set_yscale('log')
cbar = fig2.colorbar(sc1, ax=ax2)
cbar.set_label('N (# of pixels)', rotation=270, fontsize=20, labelpad=20)
# Reference line at ratio 1, where both codes take the same time.
ax2.plot([0, 220], [1, 1], '--k')
# Adding plotting parameters
ax2.set_title('Timing ratios of resampling runs', fontsize=20)
ax2.set_xlabel('M (# of pixels)', fontsize=20)
ax2.set_ylabel('time ratio (scarlet / galsim sinc)', fontsize=20)
fig2.savefig('galsim_sinc_scarlet_ratio.png')
plt.show()

#Timing ratios as a function of N
fig2, ax2 = plt.subplots(figsize=(10, 6))
sc1 = ax2.scatter(n_hrs, ratio_sinc, c=n_lrs, marker='o', cmap='winter')
ax2.set_yscale('log')
cbar = fig2.colorbar(sc1, ax=ax2)
cbar.set_label('M (# of pixels)', rotation=270, fontsize=20, labelpad=20)
ax2.plot([0, 220], [1, 1], '--k')
# Adding plotting parameters
ax2.set_title('Timing ratios of resampling runs', fontsize=14)
ax2.set_xlabel('N (# of pixels)', fontsize=20)
ax2.set_ylabel('time ratio (scarlet / galsim sinc)', fontsize=20)
plt.show()

#Colour representation of the timing ratios in the (N, M) plane
fig2, ax2 = plt.subplots(figsize=(10, 8))
sc1 = ax2.scatter(n_hrs, n_lrs, c=np.log10(ratio_sinc), vmax=2, vmin=-2, marker='o', cmap='seismic')
ax2.set_ylim([0, 220])
cbar = fig2.colorbar(sc1, ax=ax2)
cbar.set_label('log10(scarlet / galsim sinc)', rotation=270, fontsize=20, labelpad=20)
# Adding plotting parameters
ax2.set_title('Timing ratios of resampling runs', fontsize=14)
ax2.set_xlabel('N (# of pixels)', fontsize=20)
ax2.set_ylabel('M (# of pixels)', fontsize=20)
fig2.savefig('measured_sinc_ratios.png')
plt.show()

In [None]:
#theoretical version: operation-count models used for the theoretical timing plots.
def galsim_timing(M, N):
    """Theoretical cost of GalSim resampling for an N x N hr image and an M x M lr image.

    Works element-wise on numpy arrays. NOTE(review): units/normalisation are not stated;
    the notebook only plots it inside a ratio with scarlet_timing.
    """
    # Uses the M argument (not a notebook-global `x`) so the model is self-contained.
    return N**2*(6**2+4**2*np.log(M*N))+M**2*(3+np.log(M))

def scarlet_timing(M, N):
    """Theoretical cost of scarlet resampling for an N x N hr image and an M x M lr image.

    Works element-wise on numpy arrays.
    """
    # NOTE(review): `M*(M*1)` equals M**2; confirm whether `M*(M+1)` was intended.
    return N**2*((M+1)*np.log(N) + M*(M*1))

In [None]:
# Grid of image sizes (pixels) on which the theoretical cost models are evaluated.
mnms = arange(20,201,10)

# x is passed as M and y as N to the cost models below.
x,y = np.meshgrid(mnms, mnms)

# NOTE: M is on the x axis here, whereas measured_ratios.png puts N on the x axis.
plt.figure(figsize = (10,8))
plt.ylim([0,220])
plt.scatter(x,y, c = np.log10(scarlet_timing(x,y)/galsim_timing(x,y)), vmax = 2, vmin = -2, cmap = 'seismic')
plt.xlabel('M', fontsize = '30')
plt.ylabel('N', fontsize = '30')
plt.title('theoretical timing ratios', fontsize = '30')
cbar = plt.colorbar()
# The colour values are log10 of the ratio, not the ratio itself.
cbar.set_label('log10(scarlet / galsim)', rotation=270, fontsize = 25, labelpad = 20)
plt.savefig('theory_timing.png')
plt.show()

In [None]:
#Theoretical scarlet cost on the same (M, N) grid (x -> M, y -> N).
# This plots log10(scarlet_timing) alone; it is not a ratio, so the labels say so.
# NOTE(review): if a ratio of dominant cost terms was intended here, the plotted
# expression must be changed, not just the labels.
plt.figure(figsize = (10,8))
plt.ylim([0,220])
plt.scatter(x,y, c = np.log10(scarlet_timing(x,y)), cmap = 'seismic')
plt.xlabel('M', fontsize = '30')
plt.ylabel('N', fontsize = '30')
plt.title('theoretical scarlet cost', fontsize = '30')
cbar = plt.colorbar()
cbar.set_label('log10(scarlet cost)', rotation=270, fontsize = 25, labelpad = 20)
plt.show()