In [None]:
from pathlib import Path
import numpy  as np
from astropy.table import Table
import matplotlib.pyplot as plt

In [None]:
# Combined photo-z / spec-z catalogs for the LS DR9 south and north footprints.
# Both files live in the same directory, so it is defined once here.
PZ_CAT_DIR = Path("/global/cfs/cdirs/desicollab/users/rongpu/data/ls_dr9.0_desi_photoz/rf")
south_cat_path = PZ_CAT_DIR / "final_pz_specz_combined_south.fits"
north_cat_path = PZ_CAT_DIR / "final_pz_specz_combined_north.fits"

In [None]:
# Work with the south catalog; the north catalog (at `north_cat_path`) is not loaded here.
cat = Table.read(south_cat_path)

In [None]:
# Objects observed at least once in each of g, r and z.
# Printed: number with grz coverage, number without, and the covered fraction.
has_g = cat['NOBS_G'] >= 1
has_r = cat['NOBS_R'] >= 1
has_z = cat['NOBS_Z'] >= 1
mask = has_g & has_r & has_z
n_covered = np.sum(mask)
print(n_covered, np.sum(~mask), n_covered/len(mask))

In [None]:
# Photometric bands: lower-case names for derived columns, upper-case names match the
# catalog column suffixes (FLUX_G, MW_TRANSMISSION_W1, ...).
bands = ['g', 'r', 'z', 'w1', 'w2']
bands_allcap = ['G', 'R', 'Z', 'W1', 'W2']

# Replace every flux whose dereddened value (FLUX / MW_TRANSMISSION) is fainter than
# magnitude `mag_max` -- this includes all zero and negative fluxes, assuming
# MW_TRANSMISSION > 0 -- with the flux of a `mag_fill` source, so the log10 in the
# magnitude cell is always defined. The fill is multiplied by MW_TRANSMISSION so that
# dereddening gives back (approximately) `mag_fill`. Columns are modified in place.
mag_max = 30
mag_fill = 100
flux_min = 10**(0.4*(22.5-mag_max))    # dereddened flux of a mag_max source (loop-invariant)
flux_fill = 10**(0.4*(22.5-mag_fill))  # dereddened flux of a mag_fill source (loop-invariant)

# Total fluxes in all five bands, then fiber fluxes in G, R and Z (same order as before,
# so `mask` ends up holding the FIBERFLUX_Z selection, as it did originally).
flux_targets = ([('FLUX_'+band, band) for band in bands_allcap]
                + [('FIBERFLUX_'+band, band) for band in ['G', 'R', 'Z']])
for flux_col, band in flux_targets:
    transmission = cat['MW_TRANSMISSION_'+band]
    mask = (cat[flux_col]/transmission < flux_min)
    cat[flux_col][mask] = flux_fill * transmission[mask]

In [None]:
def _dereddened_mag(table, flux_col, band):
    """Magnitude 22.5 - 2.5*log10(flux / MW_TRANSMISSION_<band>) for one flux column."""
    return 22.5 - 2.5*np.log10(table[flux_col]/table['MW_TRANSMISSION_' + band])

# Total-flux magnitudes: gmag, rmag, zmag, w1mag, w2mag (added in this order).
for lower, upper in zip(('g', 'r', 'z', 'w1', 'w2'), ('G', 'R', 'Z', 'W1', 'W2')):
    cat[lower + 'mag'] = _dereddened_mag(cat, 'FLUX_' + upper, upper)

# Fiber magnitudes: gfibermag, rfibermag, zfibermag (added in this order).
for lower, upper in zip(('g', 'r', 'z'), ('G', 'R', 'Z')):
    cat[lower + 'fibermag'] = _dereddened_mag(cat, 'FIBERFLUX_' + upper, upper)

In [None]:
# Count objects with an exactly-zero color. They are only reported; nothing is removed here.
# NOTE(review): presumably many are objects whose fluxes were both replaced by the fill
# value in the flux-cleaning cell -- confirm.
g_minus_r = cat['gmag'] - cat['rmag']
r_minus_z = cat['rmag'] - cat['zmag']
z_minus_w1 = cat['zmag'] - cat['w1mag']
mask = (g_minus_r == 0) | (r_minus_z == 0) | (z_minus_w1 == 0)
print(np.sum(mask), 'objects have zero color in g-r, r-z or z-W1 but still kept')

w1_minus_w2 = cat['w1mag'] - cat['w2mag']
mask = w1_minus_w2 == 0
print(np.sum(mask), 'objects with zero color W1-W2 but still kept')

# Axis ratio from the ellipticity modulus |e| = sqrt(SHAPE_E1^2 + SHAPE_E2^2).
# NOTE(review): if |e| follows the Tractor convention |e| = (1 - b/a) / (1 + b/a), then this
# q equals a/b (>= 1), the reciprocal of the semiminor/semimajor ratio -- confirm intent.
e_squared = cat['SHAPE_E1']**2 + cat['SHAPE_E2']**2
e = np.array(np.sqrt(e_squared))
q = (1 + e) / (1 - e)

# Shape probability (definition from Soo et al. 2017):
# DCHISQ_DEV / (DCHISQ_DEV + DCHISQ_EXP) where both are positive, 0.5 otherwise.
# DCHISQ column 2 is DCHISQ_EXP and column 3 is DCHISQ_DEV.
dchisq_exp = cat['DCHISQ'][:, 2]
dchisq_dev = cat['DCHISQ'][:, 3]
mask_chisq = (dchisq_dev > 0) & (dchisq_exp > 0)
p = np.full(len(cat), 0.5)
p[mask_chisq] = dchisq_dev[mask_chisq] / (dchisq_dev + dchisq_exp)[mask_chisq]



In [None]:
# List the catalog's column names. `cat` holds the south catalog read from
# `south_cat_path`; no variable named `south_cat` exists in this notebook.
# np.unique(cat['SURVEY'])
cat.colnames

In [None]:
def prepare_phot(cat, bands=("G", "R", "Z", "W1", "W2"), fiber_bands=()):
    """Add magnitude columns, corrected by MW_TRANSMISSION, to `cat` in place.

    For each band ``m`` in `bands`, writes ``MAG_{m}`` computed as
    ``22.5 - 2.5*log10(FLUX_{m} / MW_TRANSMISSION_{m})``. For each band in
    `fiber_bands`, writes ``FIBERMAG_{m}`` from ``FIBERFLUX_{m}`` in the same way.
    The ratio is clipped from below at 1e-16 before the log, so zero or negative
    fluxes give a finite magnitude of 62.5 instead of NaN/-inf.

    Parameters
    ----------
    cat : table-like
        Any object supporting ``cat[name]`` column access and ``cat[name] = array``
        assignment (e.g. an astropy Table, or a dict of numpy arrays).
    bands : iterable of str
        Bands for total-flux magnitudes; defaults to the original five bands.
    fiber_bands : iterable of str
        Bands for fiber magnitudes, e.g. ("G", "R", "Z"). Empty by default, so
        existing ``prepare_phot(cat)`` calls behave exactly as before.

    Returns
    -------
    None. `cat` is modified in place.
    """
    def _mag(flux_col, band):
        # Clip before the log so non-positive fluxes map to a finite, very faint magnitude.
        dered_flux = cat[flux_col] / cat[f"MW_TRANSMISSION_{band}"]
        return 22.5 - 2.5*np.log10(np.clip(dered_flux, 1e-16, None))

    for m in bands:
        cat[f"MAG_{m}"] = _mag(f"FLUX_{m}", m)
    for m in fiber_bands:
        cat[f"FIBERMAG_{m}"] = _mag(f"FIBERFLUX_{m}", m)



# Notes

Features Rongpu used: r-band magnitude and r-band fiber magnitude; g − r, r − z, z − W1 and W1 − W2 colors; half-light radius;
aspect ratio (ratio between the semiminor and semimajor axes); and a shape parameter that quantifies whether a galaxy is better fit by an exponential profile or a de Vaucouleurs profile.


- What if we train the model on flux scales rather than magnitude scales?
- Do we keep north and south separate, or combine them and use north/south as an input flag?
- Look at the histograms of the feature distributions; maybe use logistic scaling? What happens to calibration with log scaling?
- Which metric are we optimizing: MSE of the median prediction, or CDE loss?
- Distance-based epistemic uncertainty?
- Try aperture photometry vs. model photometry.

- Reference DR9 DESI photo-z code: https://github.com/rongpu/desi-photoz/tree/master/dr9_desi

- https://desi.lbl.gov/DocDB/cgi-bin/private/RetrieveFile?docid=7584;filename=CWR_revised.pdf;version=2

In [None]:
# Compare the spectroscopic-redshift distributions of the two footprints.
# `cat` is the south catalog read earlier. The north catalog is read here because no
# earlier cell loads it, so the cell also runs on a fresh kernel.
north_cat = Table.read(north_cat_path)

fig, ax = plt.subplots()
ax.hist(cat["Z_SPEC"], bins=100, histtype="step", label="south")
ax.hist(north_cat["Z_SPEC"], bins=100, histtype="step", label="north")
ax.set_yscale("log")
ax.set_xlabel("Z_SPEC")
ax.set_ylabel("Number of objects")
ax.legend();

In [None]:
from collections import deque

In [None]:
%%timeit
# Micro-benchmark: time appending 8,000,000 ints to a Python list, one element at a time.
# The next cell times the same loop with collections.deque for comparison.
# Statements are left exactly as-is because they are what is being timed.
a = []
for i in range(8000000):
    a.append(i*3)

In [None]:
%%timeit
# Micro-benchmark: the same 8,000,000-element append loop as the previous cell,
# but into a collections.deque instead of a list.
# Statements are left exactly as-is because they are what is being timed.
a = deque()
for i in range(8000000):
    a.append(i*3)