In [1]:
################# DTU Master Project #################
################# Jupyter Notebook script to be converted to .py file to run in HPC #################
import warnings
import numpy as np
from astropy.utils.data import download_file
from astropy.utils.data import clear_download_cache
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy.modeling import models, fitting
from astropy.modeling.models import Sersic2D
from astropy.modeling.models import Gaussian2D
from astropy.visualization import LogStretch
from astropy.convolution import convolve, Gaussian2DKernel
from astropy.modeling.models import custom_model
from astropy.modeling import Fittable2DModel
from scipy.special import gammaincinv
from scipy.special import gamma
from astropy.io import fits
from IPython.display import Image
from astropy.cosmology import Planck15 as cosmo
from astropy.stats import sigma_clipped_stats
import os
import astropy.units as u
from astropy.coordinates import SkyCoord
import astropy.io.fits as pyfits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D
import scipy.ndimage as nd
import sep
import time
from pysersic import check_input_data
from pysersic.results import plot_image
from pysersic import FitSingle
from pysersic.loss import student_t_loss, gaussian_loss
from jax.random import PRNGKey # Need to use a seed to start jax's random number generation
from pysersic.results import plot_residual
from pysersic.priors import autoprior
from pysersic.multiband import FitMultiBandPoly
import jax
import arviz as az
import grizli
import grizli.catalog
from grizli import utils
import eazy
import eazy.hdf5
import scipy.stats as st
from matplotlib.colors import LogNorm
import matplotlib as mpl
import csv
import pandas as pd
import asdf
import corner
import matplotlib.lines as mlines
import matplotlib.ticker as mticker
from pysersic.priors import PySersicSourcePrior, estimate_sky
from skimage.measure import block_reduce

In [33]:
# Print the version of the Python interpreter running this notebook, so the
# environment used for the analysis is recorded next to the results.
from platform import python_version

print(python_version())

3.12.9


# Functions to load and save the cutout files, and to build the full weight (Full wht) and full exposure (Full exp) maps

In [1]:
class CalcSize:
    def load_ims(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE, n):
        """Download grizli mosaics, cut a stamp around every source and save it to disk.

        For every requested extension, field and filter the mosaic
        ``N0 + field + '-grizli-' + version + '-' + filter + '_drc_' + ext + '.fits.gz'``
        is downloaded, one cutout per (RA, DEC) pair is written as a FITS file and
        the written file is re-opened and collected in the returned dictionary.

        Parameters
        ----------
        N0 : str
            Prefix (URL or path) prepended to every mosaic file name.
        Nc : sequence of str
            Field names, one entry per field.
        Nv : sequence of str
            Version string of each field, used in the mosaic file name.
        Nfilter : sequence of sequences of str
            Filter names for each field.
        RA, DEC : sequence of sequences of float
            Source coordinates in degrees for each field.
        IMSIZE : sequence of sequences of float
            Cutout side in arcsec for each source. If a field's entry is empty or
            contains a zero, the whole entry is replaced IN PLACE by 3 arcsec for
            every source of that field.
        n : int
            0 loads the 'sci', 'wht' and 'exp' images; any other value loads only 'sci'.

        Returns
        -------
        dict
            ``Im_output[ext][field][filter]`` is a list of ``(HDUList, 'ra_dec')``
            tuples, one per source. The HDULists are left open for the caller.
        """
        start = time.time()
        # Decide whether the interest is on science or all images
        if n == 0:
            names = ['sci', 'wht', 'exp']
        else:
            names = ['sci']

        Im_output = {}
        ra = RA
        dec = DEC
        imsize = IMSIZE
        if len(ra) != len(dec) or len(ra) != len(imsize) or len(Nc) != len(ra):
            warnings.warn('Size of RA, DEC or IMSIZE do not correspond. It must be len(RA) = len(DEC) = len(IMSIZE) = len(Nc).')
        # With a single field the per-field nesting may be omitted by the caller: add it back.
        if len(Nc) == 1 and len(RA) != 1:
            ra = np.array([RA])
        if len(Nc) == 1 and len(DEC) != 1:
            dec = np.array([DEC])
        if len(Nc) == 1 and len(IMSIZE) != 1:
            imsize = np.array([IMSIZE])
        if len(Nc) == 1 and len(Nfilter) != 1:
            Nfilter = np.array([Nfilter])
        # Cutout size must be given. A zero (or an empty entry) means "no preference":
        # fall back to 3 arcsec for every source of that field.
        for i in range(len(Nc)):
            if len(imsize[i]) == 0 or any(size == 0 for size in imsize[i]):
                # Index by field position; indexing `ra` with the field name raised a TypeError.
                imsize[i] = 3*np.ones(len(ra[i]))

        # --- Path needs to be given by the users
        Path_to_cutout = '/work3/s240096/DTU_project/cutouts_hom/cutout_'
        for ext in names:
            Im_output[ext] = {}
            for l in range(len(Nc)):
                Im_output[ext][Nc[l]] = {}
                for F in Nfilter[l]:
                    Im_output[ext][Nc[l]][F] = []
                    ### Prepare file
                    _file = N0 + Nc[l] + '-grizli-' + Nv[l] + '-' + F + '_drc_' + ext + '.fits.gz'
                    print('Opening file ', _file,' ...')
                    local_path = download_file(_file, cache=True)
                    # Close the mosaic as soon as its cutouts are written so it does not
                    # stay open (and in memory) while the following mosaics are processed.
                    with fits.open(local_path) as mosaic:
                        wcs = WCS(mosaic[0])
                        header = mosaic[0].header
                        print('Cutting file... ')
                        ### Cut file
                        for k in range(len(ra[l])):
                            # The side is an angular size, so the same value applies to every
                            # field and filter (the former GDS/GDN f150w/f200w special case
                            # multiplied it by 2*0.5 = 1, i.e. applied no change).
                            side = imsize[l][k]*u.arcsec
                            pos = SkyCoord(ra[l][k]*u.deg, dec[l][k]*u.deg, frame='fk5')
                            cutout = Cutout2D(mosaic[0].data, position=pos, size=side, wcs=wcs)
                            updated_header = cutout.wcs.to_header()
                            ### ----------------
                            if header.get('PHOTMJSR') is None:
                                ## In case of missing parameter, it happened once... ##
                                header['PHOTMJSR'] = 0.4
                            # Carry over the photometric keywords that GetExpWht reads later.
                            updated_header['PHOTMJSR'] = header['PHOTMJSR']
                            updated_header['PHOTSCAL'] = header['PHOTSCAL']
                            if 'OPHOTFNU' in header:
                                updated_header['OPHOTFNU'] = header['OPHOTFNU']
                                updated_header['PHOTFNU'] = header['PHOTFNU']
                            # --- Define cutout name ---
                            Cname = F + '_' + str(ra[l][k]) + '_' + str(dec[l][k])
                            cutout_name = Path_to_cutout + Cname + '_' + ext + '.fits'
                            print('Saving file ', cutout_name,' ...')
                            cutout_hdul = fits.PrimaryHDU(data = cutout.data, header = updated_header)
                            cutout_hdul.writeto(cutout_name, overwrite = True)

                            # Re-open the saved cutout; it is returned together with its 'ra_dec' label.
                            Im_temp = fits.open(cutout_name)
                            Im_output[ext][Nc[l]][F].append((Im_temp, str(ra[l][k])+'_'+str(dec[l][k])))
                    ## Clear cache at every filter step
                    clear_download_cache()
        end = time.time()
        length = end - start
        print('Loading has implied ', length/60, ' minutes to run')
        return Im_output
        
    def GetMask(N0, Nc, Nfilter, ra, dec):
        """Create and save a neighbour mask for every cutout.

        For each field, filter and source the science cutout (``*_sci.fits``) and
        its sigma map (``*_sigma.npy``) are read, sources are detected with
        ``sep.extract`` at a threshold of 3 and every segment except the one
        covering the central pixel is flagged with 1. The mask is written to
        ``*_mask.npy`` next to the cutout.

        Parameters
        ----------
        N0 : str
            Not used by this method; kept so existing calls keep working.
        Nc : sequence of str
            Field names, one entry per field.
        Nfilter : sequence of sequences of str
            Filter names for each field.
        ra, dec : sequence of sequences of float
            Source coordinates for each field; they are only used to rebuild the
            cutout file names.

        Returns
        -------
        None
        """
        # --- Path needs to be given by the users
        Path_to_cutout = '/work3/s240096/DTU_project/cutouts_hom/cutout_'
        for i in range(len(Nc)):
            for F in Nfilter[i]:
                for k in range(len(ra[i])):
                    Cname = F + '_' + str(ra[i][k]) + '_' + str(dec[i][k])
                    # astype("f4") copies the data, so the file can be closed right away.
                    with fits.open(Path_to_cutout + Cname + "_sci.fits") as hdul:
                        img = hdul[0].data.astype("f4")
                    rms = np.load(Path_to_cutout + Cname + "_sigma.npy")
                    ## Derive a mask based on the chosen band
                    cat, seg = sep.extract(img, thresh = 3., err = rms, segmentation_map = True,)
                    # Central pixel of the cutout, computed per axis with integer division:
                    # round(len*0.5) rounds half to even (75 -> 38 instead of 37) and a single
                    # index for both axes fails on non-square cutouts.
                    row_c = seg.shape[0] // 2
                    col_c = seg.shape[1] // 2
                    obj_id = seg[row_c, col_c]
                    mask = seg.copy()
                    # Unmask the segment under the central pixel, flag every other segment.
                    mask[seg == obj_id] = 0
                    mask[mask >= 1] = 1
                    np.save(Path_to_cutout + Cname + "_mask.npy", mask)
                    print(Path_to_cutout + Cname + "_mask.npy")
                ## Clear cache at every filter step. Useful to free memory and avoid self-kill
                clear_download_cache()
                
    def GetExpWht(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE):
        """Build the 'Full exp' and 'Full wht' maps of every cutout and save its sigma map.

        The sci/wht/exp cutouts are created through ``CalcSize.load_ims``. For each
        of them the exposure cutout is placed on every fourth pixel of an array with
        the science cutout's shape and grown with a 4-pixel maximum filter; a
        Poisson variance term derived from it is added to the variance of the
        ``wht`` cutout. The resulting per-pixel sigma is written to ``*_sigma.npy``
        and, at the end, ``CalcSize.GetMask`` is called for all sources.

        Parameters
        ----------
        N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE
            Forwarded unchanged to ``CalcSize.load_ims`` (``N0``, ``Nc``,
            ``Nfilter``, ``RA`` and ``DEC`` also to ``CalcSize.GetMask``).

        Returns
        -------
        dict
            The dictionary returned by ``load_ims`` with two extra entries,
            ``'Full exp'`` and ``'Full wht'``, laid out like the others:
            ``[field][filter]`` -> list of ``(HDUList, 'ra_dec')`` tuples.
        """
        img = CalcSize.load_ims(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE, 0)
        if img == 0:
            return print('Deprecated')
        img['Full exp'] = {}
        img['Full wht'] = {}
        cutout_prefix = '/work3/s240096/DTU_project/cutouts_hom/cutout_'
        for field in Nc:
            exp_by_band = {}
            wht_by_band = {}
            img['Full exp'][field] = exp_by_band
            img['Full wht'][field] = wht_by_band
            for band in img['sci'][field]:
                exp_maps = []
                wht_maps = []
                exp_by_band[band] = exp_maps
                wht_by_band[band] = wht_maps
                sci_entries = img['sci'][field][band]
                for idx in range(len(sci_entries)):
                    sci_hdu = sci_entries[idx][0][0]
                    # Grow the exposure map to the frame of the science cutout
                    exposure = np.zeros(sci_hdu.data.shape, dtype=int)
                    exp_hdulist, exp_label = img['exp'][field][band][idx]
                    exp_hdu = exp_hdulist[0]
                    try:
                        exposure[1::4, 1::4] += exp_hdu.data * 1
                    except ValueError:
                        # Shapes did not match with an offset of 1: retry from pixel 0.
                        exposure[0::4, 0::4] += exp_hdu.data * 1
                    exposure = nd.maximum_filter(exposure, 4)
                    exp_maps.append((fits.HDUList([fits.PrimaryHDU(data=exposure)]), exp_label))

                    # Multiplicative factors that have been applied since the original count-rate images
                    exp_header = exp_hdu.header
                    phot_scale = 1.
                    for keyword in ('PHOTMJSR', 'PHOTSCAL'):
                        print(f'{keyword} {exp_header[keyword]:.3f}')
                        phot_scale /= exp_header[keyword]
                    if 'OPHOTFNU' in exp_header:
                        phot_scale *= exp_header['PHOTFNU'] / exp_header['OPHOTFNU']

                    # "effective gain" = electrons per DN of the mosaic
                    gain = phot_scale * exposure
                    if np.min(np.abs(gain)) == 0:
                        # Offset the whole map so the division below never hits zero.
                        gain = gain + 1e-6
                    # Poisson variance in mosaic DN
                    poisson_var = np.maximum(sci_hdu.data, 0) / gain

                    # Original variance from the `wht` image = RNOISE + BACKGROUND
                    wht_hdulist, wht_label = img['wht'][field][band][idx]
                    wht_hdu = wht_hdulist[0]
                    if np.min(np.abs(wht_hdu.data)) == 0:
                        # Same offset trick; note this modifies the returned wht cutout.
                        wht_hdu.data = wht_hdu.data + 1e-6
                    # New total variance and the weight derived from it
                    total_var = 1 / wht_hdu.data + poisson_var
                    combined_wht = 1 / total_var
                    # Null weights
                    combined_wht[total_var <= 0] = 0
                    combined_hdu = fits.PrimaryHDU(data = combined_wht, header = wht_hdu.header)
                    wht_maps.append((fits.HDUList([combined_hdu]), wht_label))

                    # Per-pixel sigma; pixels with non-positive weight get 0.1.
                    sigma = np.where(combined_hdu.data > 0, 1 / np.sqrt(combined_hdu.data), 0.1)
                    sigma_file = cutout_prefix + band + '_' + wht_label + "_sigma.npy"
                    np.save(sigma_file, sigma)
                    print(sigma_file)
                print('Making masks... ')
        clear_download_cache()
        CalcSize.GetMask(N0, Nc, Nfilter, RA, DEC)
        return img
    
    def GetSize(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE, psf_path, quick_size, plot_im_mask_psf, plot_resid, prior_sel):
        """Fit a PSF-convolved 'sersic_exp' model to every cutout, one band at a time.

        For each field, filter and source the science cutout, its ``*_mask.npy``
        and ``*_sigma.npy`` files and the PSF model of the band are loaded, the
        source is fitted with ``pysersic.FitSingle`` (Student-t loss), the residual
        plot is saved as ``<Cname>.pdf`` and, when ``quick_size`` is False, a
        posterior estimate is saved as ``<Cname>.asdf``.

        Parameters
        ----------
        N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE
            Forwarded to ``CalcSize.GetExpWht``, which creates the cutouts, sigma
            maps and masks read here. ``Nc`` holds the field names, ``Nv`` their
            version strings, ``Nfilter`` the filters of each field and
            ``RA``/``DEC`` the source coordinates of each field.
        psf_path : str
            Prefix of the PSF files; the file read for a band is
            ``psf_path + field + '-grizli-' + version + '-' + filter + '_drc_cat_star_psf.psf'``.
        quick_size : bool
            When False, a posterior is estimated and saved after the MAP fit.
        plot_im_mask_psf : bool
            When True, image, mask, sigma map and PSF are plotted before fitting.
        plot_resid : bool
            When True the residual figure is left open after being saved;
            otherwise it is closed.
        prior_sel : bool
            False uses ``autoprior``; any other value builds a custom prior from a
            ``sep`` detection of the central source.

        Returns
        -------
        None
        """
        def _fit_psf_to_cutout(psf, cutout_side):
            """Centre-crop ``psf`` so its side does not exceed ``cutout_side`` and normalise it to unit sum."""
            excess = max(len(psf) - cutout_side, 0)
            # Same number of pixels removed from every edge, so the PSF peak stays centred.
            trim = (excess + 1) // 2
            if trim > 0:
                psf = psf[trim:psf.shape[0] - trim, trim:psf.shape[1] - trim]
            return psf / np.sum(psf), excess

        start = time.time()
        ra = RA
        dec = DEC
        if len(ra) != len(dec) or len(ra) != len(IMSIZE) or len(Nc) != len(ra):
            warnings.warn('Size of RA, DEC or IMSIZE do not correspond. It must be len(RA) = len(DEC) = len(IMSIZE) = len(Nc).')
        # With a single field the per-field nesting may be omitted by the caller: add it back.
        if len(Nc) == 1 and len(RA) != 1:
            ra = np.array([RA])
        if len(Nc) == 1 and len(DEC) != 1:
            dec = np.array([DEC])
        if len(Nc) == 1 and len(Nfilter) != 1:
            Nfilter = np.array([Nfilter])
        asdf_store_name = '/work3/s240096/DTU_project/asdf_files_bulge'
        try:
            os.makedirs(asdf_store_name)
            print(f"Directory asdf_files created successfully.")
        except FileExistsError:
            print(f"Directory asdf_files already exists.")
        except PermissionError:
            print(f"Permission denied: Unable to create asdf_files.")
        except Exception as e:
            print(f"An error occurred: {e}")
        #######################################################
        ## Comment the following line if files already exist ##
        #######################################################
        im = CalcSize.GetExpWht(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE)

        Path_to_cutout = "/work3/s240096/DTU_project/cutouts_hom/cutout_"
        for i in range(len(Nc)):
            rkey = jax.random.PRNGKey(i)
            for j in range(len(Nfilter[i])):
                F = Nfilter[i][j]
                directory_name = f'/work3/s240096/DTU_project/residual_plots_bulge/{F}'
                try:
                    os.makedirs(f'{directory_name}')
                    print(f"Directory '{directory_name}' created successfully.")
                except FileExistsError:
                    print(f"Directory '{directory_name}' already exists.")
                except PermissionError:
                    print(f"Permission denied: Unable to create '{directory_name}'.")
                except Exception as e:
                    print(f"An error occurred: {e}")

                for k in range(len(ra[i])):
                    Cname = F + '_' + str(ra[i][k]) +'_'+ str(dec[i][k])
                    ##############################################################
                    ## Comment the following line if cutout files already exist ##
                    ##############################################################
                    # Cast to native float32 as GetMask and FitBands do: FITS data is
                    # big-endian and sep.extract only accepts native byte order.
                    image = im['sci'][Nc[i]][F][k][0][0].data.astype("f4")
                    ##################################################################
                    ## Uncomment the following two lines if the files already exist ##
                    ##################################################################
                    # image = fits.open(Path_to_cutout + Cname + '_' + 'sci' + '.fits')
                    # image = image[0].data.astype("f4")
                    ##################################################################
                    mask = np.load(Path_to_cutout + Cname + "_mask.npy").astype("f4")
                    sig = np.load(Path_to_cutout + Cname + "_sigma.npy").astype("f4")
                    ## Clear cache
                    clear_download_cache()
                    ###
                    PSFname = psf_path + Nc[i] + '-grizli-' + Nv[i] + '-' + F + '_drc_cat_star'
                    # astype("f4") copies the data, so the PSF file can be closed right away.
                    with fits.open(PSFname + '_psf.psf') as psf_raw:
                        psf = psf_raw[1].data[0][0][0].astype("f4")
                    # Crop (only if the PSF is larger than the cutout) and normalise. The previous
                    # slice needed `lim1`, which was undefined when no crop was required, and it
                    # removed more pixels from one side than the other, shifting the PSF by a pixel.
                    psf, Np = _fit_psf_to_cutout(psf, len(mask))
                    print('size difference between psf and mask was', Np)
                    # Start defining which method needs to be used in order to calculate sizes
                    if plot_im_mask_psf == True:
                        fig, ax = plot_image(image, mask, sig, psf)
                    try:
                        check_input_data(data = image, rms = sig, psf = psf, mask = mask)
                    except Warning:
                        # NOTE: only reached when warnings are escalated to exceptions
                        # (e.g. warnings.simplefilter('error')); otherwise they are just shown.
                        print(f'pysersic.exceptions.RMSWarning: Source{Cname} got invalid StudentT distribution')
                        continue
                    ## Choose the profile type
                    profile = 'sersic_exp'
                    ## --- Set autoprior ---
                    if prior_sel == False:
                        prior = autoprior(image = image, profile_type = profile, mask = mask, sky_type = 'none')
                        fitter = FitSingle(data = image, rms = sig, mask = mask, psf = psf, prior = prior, loss_func = student_t_loss)
                    else:
                        sky_med, sky_std, n_pix = estimate_sky(image=image, mask = mask)
                        sky_med_unc = sky_std/np.sqrt(n_pix)
                        ## --- Set custom prior --- ##
                        ##############################
                        custom_prior = PySersicSourcePrior(profile_type = profile, sky_type= 'none', sky_guess=sky_med, sky_guess_err= 2*sky_med_unc)
                        # Centre of the cutout in pixels, taken from the image itself so it holds for
                        # any pixel scale (the former imsize/0.04*0.5 assumed 0.04"/pixel).
                        # Assumes xc runs along columns and yc along rows; identical for square cutouts.
                        xc = 0.5*image.shape[1]
                        yc = 0.5*image.shape[0]

                        ## ------- Reasonable value for the flux --------- ##
                        cat, seg = sep.extract(image, thresh = 3., err = sig, segmentation_map = True,)
                        if len(cat['a'])==0 or len(cat['flux'])==0:
                            continue
                        # Use the detection covering the central pixel (segment id = catalogue
                        # index + 1), i.e. the object GetMask leaves unmasked, instead of whatever
                        # object happens to be first in the catalogue. If the central pixel is
                        # background, fall back to the detection closest to the centre.
                        central_id = seg[image.shape[0] // 2, image.shape[1] // 2]
                        if central_id > 0:
                            target = int(central_id) - 1
                        else:
                            target = int(np.argmin((cat['x'] - xc)**2 + (cat['y'] - yc)**2))
                        flux_guess = cat['flux'][target]
                        sem_maj_axis = cat['a'][target]
                        custom_prior.set_gaussian_prior('r_eff', sem_maj_axis, 0.4*sem_maj_axis)
                        custom_prior.set_gaussian_prior('flux', flux_guess, 0.4*flux_guess)
                        custom_prior.set_gaussian_prior('xc', xc, 2)
                        custom_prior.set_gaussian_prior('yc', yc, 2)
                        custom_prior.set_uniform_prior('n', 0.5, 9.0)
                        custom_prior.set_uniform_prior('ellip', 0.0, 1.0)
                        custom_prior.set_uniform_prior('theta', 0, 2*np.pi)

                        # Advance the random key so every source is fitted with a different one.
                        rkey, subkey = jax.random.split(rkey)
                        fitter = FitSingle(data = image, rms = sig, mask = mask, psf = psf, prior = custom_prior, loss_func = student_t_loss)

                    try:
                        map_params = fitter.find_MAP(rkey = rkey) # To be given as output with some flag selection?
                        fig, ax = plot_residual(image, map_params['model'], mask = mask, vmin = -1, vmax = 1)
                        plt.savefig(directory_name+'/'+Cname+'.pdf')
                        if plot_resid != True:
                            # Figure only needed on disk: close it to keep memory bounded in long runs.
                            plt.close(fig)
                    except ValueError:
                        print(f'Source{Cname} got invalid StudentT distribution')
                        continue

                    # Heavy version of the previous one: if user does not have timing issues
                    if quick_size == False:
                        # For speed, two-component profiles use the Laplace approximation
                        # instead of full sampling.
                        if profile == 'sersic_exp' or profile == 'doublesersic':
                            res = fitter.estimate_posterior(rkey=PRNGKey(1001), method="laplace")
                            summary = fitter.svi_results.summary()
                            fitter.svi_results.save_result(f'{asdf_store_name}/{Cname}.asdf')
                        else:
                            fitter.sample(rkey = rkey)
                            sampling_res = fitter.sampling_results
                            fitter.sampling_results.save_result(f'{asdf_store_name}/{Cname}.asdf')
        end = time.time()
        length = end - start
        print('Size calculation has implied ', length/60, ' minutes to run')
        
    def FitBands(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE, psf_path, wvList, n_order, plot_fits, plot_resid, prior_sel):
    #P = CalcSize.GetSize(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE, psf_path)
        wv_list = wvList
        start = time.time()
        ra = RA
        dec = DEC
        imsize = IMSIZE
        output_samp = {}
        if len(ra) != len(dec) or len(ra) != len(imsize):
            warnings.warn('Size of RA, DEC or IMSIZE do not correspond. It must be len(RA) = len(DEC) = len(IMSIZE).')
        if len(Nc) == 1 and len(RA) != 1:
            ra = np.array([RA])
        if len(Nc) == 1 and len(DEC) != 1:
            dec = np.array([DEC]) 
        if len(Nc) == 1 and len(IMSIZE) != 1:
            imsize = np.array([IMSIZE]) 
        if len(Nc) == 1 and len(Nfilter) != 1:
            Nfilter = np.array([Nfilter])
            
        fitter_dict = {}
        ind_res_dict = {}
        sort_wv_list = []
        aux_wv_list = []
        cwd = os.getcwd()

        directory_name = '/work3/s240096/DTU_project/fitted_bands_hom'
        try:
            os.makedirs(f'{directory_name}')
            print(f"Directory '{directory_name}' created successfully.")
        except FileExistsError:
            print(f"Directory '{directory_name}' already exists.")
        except PermissionError:
            print(f"Permission denied: Unable to create '{directory_name}'.")
        except Exception as e:
            print(f"An error occurred: {e}")
        asdf_store_name = '/work3/s240096/DTU_project/asdf_files_multi'
        try:
            os.makedirs(asdf_store_name)
            print(f"Directory asdf_files created successfully.")
        except FileExistsError:
            print(f"Directory asdf_files already exists.")
        except PermissionError:
            print(f"Permission denied: Unable to create asdf_files.")
        except Exception as e:
            print(f"An error occurred: {e}")
        #######################################################
        ## Comment the following line if files already exist ##
        #######################################################
        im = CalcSize.GetExpWht(N0, Nc, Nv, Nfilter, RA, DEC, IMSIZE)
        for i in range(len(Nc)):
            aux_wv_list.append([str(np.round( np.array(wvList[i][b])*100 )) for b, band in enumerate(wvList[i])])
            
            for j in range(len(Nfilter[i])):
                F = Nfilter[i][j]
                f = aux_wv_list[i][j]
                # F and f should have corresponding ordering of filters. The first filter in the list will be used as prior for the fitting
                ###
                fitter_dict[f] = {} 
                ind_res_dict[f] = {} 
                ###
                rkey = jax.random.PRNGKey(5+3*j)
                for k in range(len(ra[i])):
                    Cname = F + '_' + str(ra[i][k]) +'_'+ str(dec[i][k])
                    Path_to_cutout = "/work3/s240096/DTU_project/cutouts_hom/cutout_"
                    # Load image
                    ##############################################################
                    ## Comment the following line if cutout files already exist ##
                    ##############################################################
                    image = im['sci'][Nc[i]][F][k][0][0].data.astype("f4")
                    ##################################################################
                    ## Uncomment the following two lines if the files already exist ##
                    ##################################################################
                    # image = fits.open(Path_to_cutout + Cname + '_' + 'sci' + '.fits')
                    # image = image[0].data.astype('f4')
                    #---------------------------------
                    # Load mask and sigma
                    mask = np.load(Path_to_cutout + Cname + '_mask.npy').astype("f4")
                    sig = np.load(Path_to_cutout + Cname + "_sigma.npy").astype("f4")
                    ## Clear cache
                    clear_download_cache()
                    ##
                    # Load PSF
                    PSFname = psf_path + Nc[i] + '-grizli-' + Nv[i] + '-' + F + '_drc_cat_star'
                    psf_raw = fits.open(PSFname + '_psf.psf')
                    psf = psf_raw[1].data[0][0][0].astype("f4")
                    ## BETA: MAYBE AVERAGING THE PSF IS NOT OK FOR REDUCING SIZE
                    # if (Nc[i] == 'gds' or Nc[i] == 'gdn') and (F == 'f200w-clear' or F == 'f150w-clear'):
                    #     # Ensure all arrays are float to avoid integer division
                    #     psf = psf.astype("f4")
                    #     mask = mask.astype("f4")
                    #     sig = sig.astype("f4")
                    #     image = image.astype("f4")
                        
                    #     # Resize all arrays using block_reduce
                    #     resized_psf = block_reduce(psf, block_size=(2, 2), func=np.mean)
                    #     resized_mask = block_reduce(mask, block_size=(2, 2), func=np.mean)
                    #     resized_sig = block_reduce(sig, block_size=(2, 2), func=np.mean)
                    #     resized_image = block_reduce(image, block_size=(2, 2), func=np.mean)
                        
                    #     # Normalize if needed
                    #     psf = resized_psf[1:len(resized_psf), 1:len(resized_psf)] / np.sum(resized_psf[1:len(resized_psf), 1:len(resized_psf)])
                        
                    #     if len(resized_mask)%2 == 0:
                    #         mask = resized_mask[1:len(resized_mask), 1:len(resized_mask)].copy(order='C')
                    #         sig = resized_sig[1:len(resized_sig), 1:len(resized_sig)].copy(order='C')
                    #         image = resized_image[1:len(resized_image), 1:len(resized_image)].copy(order='C')
                    #     else:
                    #         mask = resized_mask.copy(order='C')
                    #         sig = resized_sig.copy(order='C')
                    #         image = resized_image.copy(order='C')

                    # else:
                    if (len(psf) - len(mask) > 0):
                        Np = len(psf) - len(mask)
                        if Np == 1:
                            lim1 = int(0.5*Np)+1
                        else: lim1 = int(0.5*Np)
                    else: 
                        Np = 0
                    psf = psf[lim1-1:len(psf)-lim1-1,lim1-1:len(psf)-lim1-1]/np.sum(psf[lim1-1:len(psf)-lim1-1, lim1-1:len(psf)-lim1-1])
                    print('size difference between psf and mask was', Np)
                    
                    pix_size = 0.04/3600    
                    if (Nc[i] == 'gds' or Nc[i] == 'gdn') and (F == 'f200w-clear' or F == 'f150w-clear'):
                            pix_size = pix_size/2.0

                    profile = 'sersic'
                    if prior_sel == False:
                        prior = autoprior(image = image, profile_type = profile, mask = mask, sky_type = 'none')
                        fitter_cur = FitSingle(
                                data = image,
                                rms = sig,
                                psf = psf,
                                prior = prior,
                                mask = mask,
                                loss_func = student_t_loss
                                )    
                    # Find priors for each band
                    else: 
                        sky_med, sky_std, n_pix = estimate_sky(image = image, mask = mask)
                        sky_med_unc = sky_std/np.sqrt(n_pix) 
                        custom_prior = PySersicSourcePrior(profile_type = 'sersic', sky_type= 'flat', sky_guess=sky_med, sky_guess_err= 2*sky_med_unc)
                        ## ------- Reasonable value for the flux --------- ##
                        if image.dtype != "f4":
                            image == image.astype("f4")
                        if sig.dtype != "f4":
                            sig == sig.astype("f4")
                        if mask.dtype != "f4":
                            mask == mask.astype("f4")
                        if psf.dtype != "f4":
                            psf == psf.astype("f4")
                        if any(dtype != 'f4' for dtype in [image.dtype, sig.dtype, mask.dtype, psf.dtype]):
                            continue
                        cat, seg = sep.extract(image, thresh=3., err = sig, segmentation_map= True,)
                        #########################################
                        xc = imsize[i][k]/0.04*0.5
                        yc = imsize[i][k]/0.04*0.5
                        # if (Nc[i] == 'gds' or Nc[i] == 'gdn') and (F == 'f200w-clear' or F == 'f150w-clear'):
                        #     xc = 2*xc
                        #     yc = 2*yc
                        # PRIOR SET
                        ## ------- Reasonable value for the flux --------- ##
                        if len(cat['a'])==0 or len(cat['flux'])==0:
                            continue
                        flux_guess = cat['flux'][0]
                        sem_maj_axis = cat['a'][0]
                        custom_prior.set_gaussian_prior('r_eff', sem_maj_axis, 0.3*sem_maj_axis)
                        custom_prior.set_gaussian_prior('flux', flux_guess, 0.3*flux_guess)
                        custom_prior.set_gaussian_prior('xc', xc, 2)
                        custom_prior.set_gaussian_prior('yc', yc, 2)
                        custom_prior.set_uniform_prior('n', 0.5, 9.)
                        custom_prior.set_uniform_prior('ellip', 0., 1.)
                        custom_prior.set_uniform_prior('theta', 0., 2*np.pi)
                        prior_dict = autoprior(image = image, profile_type = profile, mask = mask, sky_type = 'flat')
                        #########
                        rkey = jax.random.PRNGKey(5+3*k)
                        rkey,_ = jax.random.split(rkey, 2) # use different random number key for each run
                        ##
                        fitter_cur = FitSingle(
                                data = image,
                                rms = sig,
                                psf = psf,
                                prior = custom_prior,
                                mask = mask,
                                loss_func = student_t_loss
                                )
                    
                    # Parameters for each pairs of coordinates(source), given the field
                    print(f'Running fit - {Cname}')
                    try:
                        ind_res_cur = fitter_cur.estimate_posterior(method = 'svi-flow', rkey = rkey)
                        fitter_cur.svi_results.save_result(f'{asdf_store_name}/{Cname}.asdf')
                    except ValueError:
                        print(f'Source{Cname} got invalid StudentT distribution')
                        continue
                    except RuntimeError:
                        print('RunTime error: no suitable initial parameters')
                        continue
                    fitter_dict[f][k] = []
                    ind_res_dict[f][k] = []
                    ind_res_dict[f][k].append(ind_res_cur.retrieve_med_std())
                    fitter_dict[f][k].append(fitter_cur)
                ###
                    
        ###########################
        ####### ALL BANDS #########
        ## The following section is commented out because pysersic does not deal with images having different pixel sizes.
        ##  Even if normalized to the same pixel size, if each image in the sequence has a different arcsec size per pixel,
        ##                                                        the fit will be misled (N pixels at 0.02" do not correspond to N pixels at 0.04")
        ###########################
        # Uncomment the following block to allow for stacking in multiple filters and start fitting
        # for i in range(len(Nc)): # i->l
        #     for m in range(len(ra[i])): # k->m
        #         ### Verify if the source has been fitted and fit has been successful
        #         check_var = 0
        #         for f, filter_name in enumerate(aux_wv_list[i]):
        #             if filter_name not in ind_res_dict or not isinstance(ind_res_dict[filter_name], dict) or m not in ind_res_dict[filter_name]:
        #                 check_var = 1
        #         if check_var == 1:
        #             continue
        #         ### Continue(Skip source) if it did not
        #         Cname_all = Nc[i] + '_' + str(ra[i][m]) + '_' + str(dec[i][m])
        #         fig, axes = plt.subplots(1, 3, figsize = (10,3))
        #         fig.suptitle(Cname_all, fontsize = 10)
        #         ###
        #         for n, param in enumerate(['n', 'ellip', 'r_eff']):# j->n
        #             ax = axes[n]
        #             if param == 'r_eff':
        #                 med_ind = [ind_res_dict[f][m][0][param][0] for f in aux_wv_list[i]]
        #                 err_ind = [ind_res_dict[f][m][0][param][1] for f in aux_wv_list[i]]
        #                 plt.plot(wv_list[i], med_ind, color = 'red', ls = '-')
        #                 ax.errorbar(wv_list[i], med_ind, yerr = err_ind, fmt = 'o', color = 'k',
        #                             label = 'Ind.', ms = 8, capsize = 3, markeredgecolor = 'k', markerfacecolor = 'red', markeredgewidth = 1.1, ls='-')
        #             else:
        #                 med_ind = [ind_res_dict[f][m][0][param][0] for f in aux_wv_list[i]]
        #                 err_ind = [ind_res_dict[f][m][0][param][1] for f in aux_wv_list[i]]
        #                 ax.errorbar(wv_list[i], med_ind, yerr = err_ind, fmt = 'o', color = 'k',
        #                             label = 'Ind.', ms = 8, capsize = 3, markeredgecolor = 'k', markerfacecolor = 'red', markeredgewidth = 1.1, ls='-')
                        
        #             axes[0].legend()
        #             if param == 'n':
        #                 param_latex = r'$n$'
        #             elif param == 'ellip':
        #                 param_latex = r'$1-q$'
        #             elif param == 'r_eff':
        #                 param_latex = r'$R_e(\mathrm{kpc}$)'
        #             ax.set_title(param_latex, fontsize = 14)
        #             ax.set_xlabel(r'Obs. $\lambda\:\mathrm{(\mu m)}$')
                
        #         plt.tight_layout()
        #         plt.savefig(directory_name+'/'+Cname_all+'.pdf')
                
                
        #         wv_to_save = np.linspace(min(wv_list[i]), max(wv_list[i]), num = 50)
        #         print(f'Running fit - {Cname_all}')
        #         try: 
        #             MultiFitter = FitMultiBandPoly(fitter_list = [fitter_dict[F][m][0] for F in aux_wv_list[i]],
        #                                             wavelengths = wv_list[i],
        #                                             band_names = Nfilter[i],
        #                                             linked_params = ['n','ellip','r_eff'],
        #                                             const_params = ['xc','yc','theta'],
        #                                             wv_to_save = wv_to_save,
        #                                             poly_order = n_order)
        #         except ValueError:
        #             print(f'Source{Cname_all} got invalid StudentT distribution')
        #             continue
                    
        #         rkey = jax.random.PRNGKey(5+3*m)
        #         rkey, subkey = jax.random.split(rkey)
        #         print(f'Estimating posterior - {Cname_all}')
        #         try:
        #             multires = MultiFitter.estimate_posterior(method = 'svi-flow', rkey = rkey)
        #         except TypeError:
        #             print('Something wrong in concatenation?')
        #             continue
        #         except RuntimeError:
        #             print('RunTime error: no suitable initial parameters')
        #             continue
                
        #         #########
        #         link_params = [f'{param}_{b}' for b in Nfilter[i] for param in ['n','ellip','r_eff']] # Look at posteriors of "linked" parameters
        #         multi_res_dict = multires.retrieve_med_std()
        #         az.summary(multires.idata, var_names=link_params)
        #         #########
        #         if plot_fits == True:    
        #             fig, axes = plt.subplots(1,3, figsize = (10,3))
        #             fig.suptitle(Cname_all, fontsize = 10)
        #             for s,param in enumerate(['n','ellip','r_eff']):
        #                 ax = axes[s]
                        
        #                 med_ind = [ind_res_dict[f][m][0][param][0] for f in aux_wv_list[i]]
        #                 err_ind = [ind_res_dict[f][m][0][param][1] for f in aux_wv_list[i]]
                
        #                 med_multi = [multi_res_dict[f'{param}_{b}'][0] for b in Nfilter[i]]
        #                 err_multi = [multi_res_dict[f'{param}_{b}'][1] for b in Nfilter[i]]
                
        #                 ax.errorbar(wv_list[i], med_ind, yerr=err_ind, fmt = 'o', color = 'k', label = 'Ind. fit')
        #                 ax.errorbar(np.array(wv_list[i])+0.01, med_multi, yerr=err_multi, fmt = 'o', color = 'C0', label = 'Joint fit')
        #                 param_smooth = multires.idata.posterior[f'{param}_at_wv'].data.squeeze()
        #                 ax.plot(wv_to_save, param_smooth[:20].T, 'C0-', alpha = 0.2)
        #                 if param == 'n':
        #                     param_latex = r'$n$'
        #                 elif param == 'ellip':
        #                     param_latex = r'$1-q$'
        #                 elif param == 'r_eff':
        #                     param_latex = r'$R_e$(pixels)'
        #                 ax.set_title(param_latex, fontsize = 14)
        #                 ax.set_xlabel(r'Obs. $\lambda\:\mathrm{(\mu m)}$')
        #             axes[0].legend()
        #             plt.tight_layout()
        #             plt.savefig(directory_name+'/'+Cname_all+'.pdf')
        #         else:
        #             fig, axes = plt.subplots(1,3, figsize = (10,3))
        #             fig.suptitle(Cname_all, fontsize = 10)
        #             for s,param in enumerate(['n','ellip','r_eff']):
        #                 ax = axes[s]
        #                 med_ind = [ind_res_dict[f][m][0][param][0] for f in aux_wv_list[i]]
        #                 err_ind = [ind_res_dict[f][m][0][param][1] for f in aux_wv_list[i]]
                
        #                 med_multi = [multi_res_dict[f'{param}_{b}'][0] for b in Nfilter[i]]
        #                 err_multi = [multi_res_dict[f'{param}_{b}'][1] for b in Nfilter[i]]
                
        #                 ax.errorbar(wv_list[i], med_ind, yerr=err_ind, fmt = 'o', color = 'k', label = 'Ind. fit')
        #                 ax.errorbar(np.array(wv_list[i])+0.01, med_multi, yerr=err_multi, fmt = 'o', color = 'C0', label = 'Joint fit')
        #                 param_smooth = multires.idata.posterior[f'{param}_at_wv'].data.squeeze()
        #                 ax.plot(wv_to_save, param_smooth[:20].T, 'C0-', alpha = 0.2)
        #                 if param == 'n':
        #                     param_latex = r'$n$'
        #                 elif param == 'ellip':
        #                     param_latex = r'$1-q$'
        #                 elif param == 'r_eff':
        #                     param_latex = r'$R_e$(pixels)'
        #                 ax.set_title(param_latex, fontsize = 14)
        #                 ax.set_xlabel(r'Obs. $\lambda\:\mathrm{(\mu m)}$')
        #             axes[0].legend()
        #             plt.tight_layout()
        #             plt.savefig(directory_name+'/'+Cname_all+'.pdf')
        #             plt.close(fig)
                    
        # return ind_res_dict

# Definition of inputs

In [4]:
# Tag interpolated into the catalogue CSV file name and the per-field SED directory names
# read in the cells below (presumably the chi-square cut used for the SED selection -- TODO confirm).
chi2 = 50

In [5]:
# Load the catalogue of selected galaxies; the file name depends on `chi2`.
posVec = pd.read_csv(
    filepath_or_buffer = f'/work3/s240096/DTU_project/SED_selctedGal_{chi2}_cut9.csv'
)
# Coordinate columns of the catalogue (the column names state degrees).
RA, DEC = posVec['RA(deg)'], posVec['DEC(deg)']

In [6]:
############################
######### REMEMBER #########
# to check the location of PSF and zout/phot_corr files
N0 = 'https://s3.amazonaws.com/grizli-v2/JwstMosaics/v7/' # Specify path beforehand
# List of the fields
ra = {}
dec = {}
Nc = ['ceers-full', 'gds', 'primer-uds-north', 'primer-uds-south', 'primer-cosmos-east', 'primer-cosmos-west', 'gdn']
# The catalogue rows are consumed field by field, in the order of Nc: each field takes as many
# consecutive rows of RA/DEC as there are entries in its good_SED directory.
counter = 0
for i in range(len(Nc)):
    ra[Nc[i]] = []
    dec[Nc[i]] = []
    lst = os.listdir(f"/work3/s240096/DTU_project/SEDs_Chi{chi2}_{Nc[i]}/good_SED") # your directory path
    number_files = len(lst)
    for j in range(counter, counter+number_files):
        ra[Nc[i]].append(RA[j])
        dec[Nc[i]].append(DEC[j])
    #print(bool(ra[Nc[i]] == list(RA[counter:counter+number_files])))
    counter = counter + number_files
# Only one version is specified for each field, since the PSF is made for different versions
Nv = ['v7.2', 'v7.2', 'v7.2', 'v7.2', 'v7.0', 'v7.0', 'v7.3']
# Filters (list). Insert them keeping in mind that the first one will serve as prior for the multi-band fit
Nf = ['f200w-clear', 'f444w-clear']
#Nf = ['f150w-clear', 'f200w-clear', 'f277w-clear', 'f356w-clear', 'f444w-clear']
# Same filter list (same object) for every field
Nfilter = [Nf] * len(Nc)
# To use more coordinates: cycle over a pre-selected vector. Also the following are list-type variables. Insert one list for each field.
CoRA = [ra[field] for field in Nc]
CoDEC = [dec[field] for field in Nc]
# Code is robust enough to deal with null IMSIZE, but it takes more time: use that only if strictly needed.
stand_imsize = 3 # in arcseconds
IMSIZE = []
for i in range(len(Nc)):
    IMSIZE.append(list(stand_imsize + np.zeros(len(ra[Nc[i]]))))
psf_path = '/work3/s240096/psf_cat/'
# Wavelength (micrometers) associated with each filter name. The wavelength list is derived
# from Nf so that it always has one entry per selected filter, in the same order; a
# hard-coded five-element list would not match a shorter Nf in the multi-band fit.
FILTER_WAVELENGTH_UM = {
    'f150w-clear': 1.50,
    'f200w-clear': 2.00,
    'f277w-clear': 2.77,
    'f356w-clear': 3.56,
    'f444w-clear': 4.44,
}
WSubList = [FILTER_WAVELENGTH_UM[band] for band in Nf]
wvList = [WSubList] * len(Nc) #micrometers
# Load only science image(_sci.fits) for testing!
# n is the extra positional argument handed to CalcSize.FitBands in the (commented) multi-band cell.
n = 2

# Compute sizes

In [1]:
#### Cutouts: mask, sigma, fits ####
# for K in range(1, len(Nc)):
#     Tsize = CalcSize.GetExpWht(N0, [Nc[K]], [Nv[K]], [Nfilter[K]], [CoRA[K]], [CoDEC[K]], [IMSIZE[K]])

In [2]:
##### Size computation: single-band ####
# One GetSize call per field: every per-field input is wrapped in a one-element list.
for K, field_name in enumerate(Nc):
    CalcSize.GetSize(
        N0, [field_name], [Nv[K]], [Nfilter[K]], [CoRA[K]], [CoDEC[K]], [IMSIZE[K]],
        psf_path,
        quick_size = False,
        plot_im_mask_psf = False,
        plot_resid = False,
        prior_sel = False,
    )

In [76]:
##### Parameters computation: multi-band ####
# for K in range(len(Nc)):
#     CalcSize.FitBands(N0, [Nc[K]], [Nv[K]], [Nfilter[K]], [CoRA[K]], [CoDEC[K]], [IMSIZE[K]],
#                                 psf_path, [wvList[K]], n, plot_fits = False, plot_resid = False, prior_sel = False)