In [None]:
# Reload edited modules automatically before each cell runs, so changes to
# imported code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib_inline.backend_inline
from astropy.table import Table
from astropy.io import fits
from pathlib import Path
import numpy as np
import eazy
import pickle
import gc
import os
import psutil
import time

# Handle on the current kernel process; in this notebook it is only used by
# the commented-out priority tweak below.
p = psutil.Process(os.getpid())
# p.nice(-1)

# Render inline matplotlib figures at high (retina) resolution.
matplotlib_inline.backend_inline.set_matplotlib_formats('retina')

# Working directory. After the chdir, the relative paths used in this notebook
# (e.g. the input catalogs under input/ and the tempfilt pickle) resolve here.
WD = Path('/data1/hbahk/spherex-photoz/spherex-challenge/7ds_challenge')
os.chdir(WD)
# NOTE(review): TEMPDIR appears unused in this notebook -- confirm or remove.
TEMPDIR = WD.parent / 'BROWN_COSMOS'

# NOTE(review): RECALCULATE_TEMPFILT is not read in this notebook; tempfilt is
# always loaded from the pickle further down -- TODO wire up or remove.
RECALCULATE_TEMPFILT = False

In [None]:
params = {}
# Filter response file, template set and template-error function.
params['FILTERS_RES']       = "FILTER.RES+7DTreduced.res"
params['TEMPLATES_FILE']    = "BROWN_COSMOS.template"
params['TEMP_ERR_FILE']     = "../template_error_cosmos2020.txt"
params['TEMP_ERR_A2']       = 0.0 # nullify the template error

params['SYS_ERR']           = 0.01

# Milky Way E(B-V); CAT_HAS_EXTCORR=False presumably tells eazy the catalog
# fluxes are not yet extinction-corrected -- confirm against the inputs.
params['MW_EBV']            = 0.016577
params['CAT_HAS_EXTCORR']   = False

params['CATALOG_FILE']      = "input/input_RIS.csv"

dir_output = WD/"output"/"output_RIS"

# exist_ok=True: re-running the cell is safe, there is no exists()/mkdir()
# race, and it matches how the WFS/IMS output directories are created.
dir_output.mkdir(parents=True, exist_ok=True)
params['OUTPUT_DIRECTORY']  = dir_output.as_posix()
params['MAIN_OUTPUT_FILE']  = "output"

# Magnitude prior. 23.9 is the AB zero point for fluxes in microJansky
# (assumed catalog unit -- confirm).
params['APPLY_PRIOR']       = True
params['PRIOR_FILE']        = "prior_R_zmax7.dat"
params['PRIOR_FILTER']      = 429  # 7DT band corresponding to 6500 Angstrom
params['PRIOR_ABZP']        = 23.9
params['PRIOR_FLOOR']       = 0.01

# Redshift grid (Z_STEP_TYPE=1 is presumably eazy's (1+z)-scaled step).
params['FIX_ZSPEC']         = False
params['Z_MIN']             = 0.015
params['Z_MAX']             = 5.8
params['Z_STEP']            = 0.01
params['Z_STEP_TYPE']       = 1

translate_filename = 'dummy.translate'

# Snapshot (shallow copy) before later cells mutate `params` for WFS/IMS.
params_RIS = params.copy()

In [None]:
# Derive the WFS and IMS configurations from `params` by swapping only the
# input catalog and the output directory (created if missing). As before,
# `params` and `dir_output` are left holding the IMS values afterwards.
_survey_params = {}
for _survey in ('WFS', 'IMS'):
    params['CATALOG_FILE']      = f"input/input_{_survey}.csv"
    dir_output = WD/"output"/f"output_{_survey}"
    dir_output.mkdir(parents=True, exist_ok=True)
    params['OUTPUT_DIRECTORY']  = dir_output.as_posix()
    _survey_params[_survey] = params.copy()
params_WFS = _survey_params['WFS']
params_IMS = _survey_params['IMS']
del _survey, _survey_params

In [None]:
def get_pit(zspec, zgrid, lnp, trdz):
    """
    Probability integral transform (PIT) of each p(z) at its reference
    redshift, used to assess p(z) calibration as in Tanaka (2017).

    PIT_i = sum_k exp(lnp[i, k]) * trdz[k] over grid points zgrid[k] <= zspec[i].

    Parameters
    ----------
    zspec : ndarray, shape (N,)
        Reference (spectroscopic) redshifts.
    zgrid : ndarray, shape (NZ,)
        Redshift grid on which ``lnp`` is tabulated.
    lnp : ndarray, shape (N, NZ)
        Log p(z) per object.
    trdz : ndarray, shape (NZ,)
        Integration weights for ``zgrid``.

    Returns
    -------
    ndarray, shape (N,)
    """
    # at_or_below[i, k] is True where grid point k lies at or below zspec[i].
    at_or_below = zspec[:, None] >= zgrid[None, :]
    masked_pz = np.exp(lnp) * at_or_below
    return masked_pz.dot(trdz)


def get_crps(zspec, zgrid, lnp, trdz):
    """
    CRPS function for evaluating the calibration of p(z),
    as described in Tanaka (2017).

    For object i the CDF of p(z) is evaluated on ``zgrid`` with the
    integration weights ``trdz`` and compared with the step function at the
    reference redshift:

        CDF_i(z_j) = sum_k exp(lnp[i, k]) * trdz[k] * [z_k <= z_j]
        CRPS_i     = sum_j (CDF_i(z_j) - [z_j >= zspec_i])**2 * trdz[j]

    All objects are processed at once, so memory scales as N x NZ. Pass
    batch slices for very large catalogs, as fit_and_save_result_prior does.

    Parameters
    ----------
    zspec : array_like, shape (N,)
        Reference (spectroscopic) redshifts.
    zgrid : array_like, shape (NZ,)
        Redshift grid on which ``lnp`` is tabulated.
    lnp : array_like, shape (N, NZ)
        Log p(z) per object.
    trdz : array_like, shape (NZ,)
        Integration weights for ``zgrid``.

    Returns
    -------
    crps : ndarray of float, shape (N,)
    """
    zgrid = np.asarray(zgrid)
    trdz = np.asarray(trdz)
    zspec = np.asarray(zspec, dtype=float)
    pz = np.exp(np.asarray(lnp, dtype=float))

    # below[j, k] = 1 where zgrid[k] <= zgrid[j]. It depends only on the grid,
    # so it is built once instead of once per object (previously O(N*NZ^2)
    # allocations inside a Python loop).
    below = (zgrid[:, None] >= zgrid[None, :]).astype(float)

    # cdf[i, j] = sum_k pz[i, k] * trdz[k] * below[j, k]
    cdf = (pz * trdz[None, :]) @ below.T
    # Heaviside step at each object's reference redshift.
    step = (zgrid[None, :] >= zspec[:, None]).astype(float)

    return ((cdf - step) ** 2) @ trdz


def _fit_and_write_batch(params, tempfilt, batch, start_id, end_id, dir_output,
                         n_proc=16):
    """
    Fit catalog rows ``[start_id, end_id)`` with eazy and write the batch files.

    A fresh ``eazy.photoz.PhotoZ`` is built from ``params``/``tempfilt`` for
    every batch and deleted afterwards, so its arrays can be garbage-collected
    before the next batch starts.

    Writes to ``dir_output``:
      * ``output{batch:02d}.fits``: id, z_phot, z_phot_chi2, z025/z160/z500/
        z840/z975 percentiles, pit and crps for the batch rows;
      * ``lnp{batch:02d}.fits``: ln p(z) of the batch rows in the primary HDU
        and the redshift grid in the first extension.
    """
    ez = eazy.photoz.PhotoZ(param_file=None, translate_file=translate_filename,
                            zeropoint_file=None, params=params,
                            tempfilt=tempfilt)
    end_id = min(end_id, len(ez.idx))

    ez.ZML_WITH_PRIOR = True
    ez.fit_catalog(ez.idx[start_id:end_id], n_proc=n_proc, prior=True)
    # Set again after fit_catalog, as in the original workflow (presumably in
    # case fit_catalog changes it -- TODO confirm whether this is needed).
    ez.ZML_WITH_PRIOR = True
    # NOTE(review): confirm the keyword eazy's fit_at_zbest expects for the
    # process count (``nproc`` vs ``n_proc``); an unrecognised keyword may be
    # silently ignored.
    ez.fit_at_zbest(prior=True, nproc=1)

    try:
        zlimits = ez.pz_percentiles(percentiles=[2.5, 16, 50, 84, 97.5],
                                    oversample=5)
    except Exception as exc:
        # Best effort: continue with -1 placeholders, but report the cause
        # (a bare except would also swallow KeyboardInterrupt).
        print(f'Couldn\'t compute pz_percentiles: {exc!r}')
        zlimits = np.zeros((ez.NOBJ, 5), dtype=ez.ARRAY_DTYPE) - 1

    tab = Table()
    tab['id'] = ez.OBJID
    tab['z_phot'] = ez.zbest
    tab['z_phot_chi2'] = ez.chi2_best
    for col, label in enumerate(['z025', 'z160', 'z500', 'z840', 'z975']):
        tab[label] = zlimits[:, col]
    tab['pit'] = ez.PIT(ez.ZSPEC)

    # CRPS is only evaluated for this batch; the remaining rows stay NaN
    # (instead of uninitialized memory) and are not written out anyway.
    tab['crps'] = np.full(len(tab), np.nan)
    tab['crps'][start_id:end_id] = get_crps(ez.ZSPEC[start_id:end_id], ez.zgrid,
                                            ez.lnp[start_id:end_id], ez.trdz)

    tab[start_id:end_id].write(dir_output / f'output{batch:02d}.fits',
                               overwrite=True)
    hdul = fits.HDUList([fits.PrimaryHDU(data=ez.lnp[start_id:end_id]),
                         fits.ImageHDU(data=ez.zgrid)])
    hdul.writeto(dir_output / f'lnp{batch:02d}.fits', overwrite=True)

    del ez, tab, hdul, zlimits
    gc.collect()


def _merge_batches(base, dir_output, n_batches,
                   colnames=('z_phot', 'z_phot_chi2', 'z160', 'z840', 'id')):
    """
    Copy ``colnames`` from the per-batch tables into ``base`` and write it as
    ``result.fits`` in ``dir_output``.

    Raises ValueError if the stacked batch tables do not have exactly one row
    per row of ``base``, which would otherwise misalign the result.
    """
    from astropy.table import vstack

    fitted = vstack([Table.read(dir_output / f'output{i:02d}.fits')
                     for i in range(n_batches)])
    if len(fitted) != len(base):
        raise ValueError(
            f'Batch outputs have {len(fitted)} rows but the input catalog has '
            f'{len(base)}; refusing to write a misaligned result.')
    for label in colnames:
        # Whole-column assignment also works when the input catalog does not
        # already contain the column (slice assignment raised KeyError).
        base[label] = fitted[label]
    base.write(dir_output / 'result.fits', overwrite=True)


def fit_and_save_result_prior(params, tempfilt, batch_size=10000, n_proc=16):
    """
    Run eazy with the magnitude prior over a whole catalog, in batches.

    ``params['CATALOG_FILE']`` is fitted ``batch_size`` rows at a time. The
    number of batches follows from the catalog length instead of being fixed.
    Each batch writes ``output{i:02d}.fits`` and ``lnp{i:02d}.fits`` to
    ``params['OUTPUT_DIRECTORY']``. Afterwards, z_phot, z_phot_chi2, z160, z840
    and id from all batches are copied into the input catalog, which is saved
    as ``result.fits`` in the same directory.

    Parameters
    ----------
    params : dict
        eazy parameters; must contain 'CATALOG_FILE' and 'OUTPUT_DIRECTORY'.
    tempfilt : object
        Precomputed template/filter grid passed through to ``PhotoZ``.
    batch_size : int, optional
        Rows fitted per batch (default 10000).
    n_proc : int, optional
        Process count passed to ``fit_catalog`` (default 16).

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive, the catalog has no rows, or the
        batch outputs do not cover the catalog row for row.
    """
    start = time.time()
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    dir_output = Path(params['OUTPUT_DIRECTORY'])

    # The input catalog determines the number of batches and is also the base
    # table of the merged result.
    base = Table.read(Path(params['CATALOG_FILE']))
    n_obj = len(base)
    if n_obj == 0:
        raise ValueError(f"Catalog {params['CATALOG_FILE']} has no rows")
    n_batches = -(-n_obj // batch_size)  # ceiling division

    for i in range(n_batches):
        print(f'Fitting {i}th batch')
        _fit_and_write_batch(params, tempfilt, i, i * batch_size,
                             min((i + 1) * batch_size, n_obj), dir_output,
                             n_proc=n_proc)

    _merge_batches(base, dir_output, n_batches)

    end = time.time()
    time_taken_hms = time.strftime('%H:%M:%S', time.gmtime(end-start))
    print(f'Finished in {time_taken_hms}')

In [None]:
# Load the precomputed tempfilt grid from the working directory.
# Unpickling can execute arbitrary code, so only load pickles you trust.
tempfilt = pickle.loads(Path('tempfilt_nored_001.pickle').read_bytes())

In [None]:
# Fit the IMS catalog (input/input_IMS.csv) with the magnitude prior.
fit_and_save_result_prior(params=params_IMS, tempfilt=tempfilt)

In [None]:
# Fit the WFS catalog (input/input_WFS.csv) with the magnitude prior.
fit_and_save_result_prior(params=params_WFS, tempfilt=tempfilt)

In [None]:
# Fit the RIS catalog (input/input_RIS.csv) with the magnitude prior.
fit_and_save_result_prior(params=params_RIS, tempfilt=tempfilt)