# Survey non-uniformity check pipeline:

This notebook outlines my pipeline, which mainly uses RAIL to check the effect of survey non-uniformity for any given choice of photo-z pipeline.

We will later turn this notebook into a RAIL pipeline; for this, cf. `golden spike rail_pipelines` under the DESC directory, which shows how to make a pipeline class.

In [None]:
# import dependencies

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from astropy.io import fits
import numpy as np

import healpy as hp
import pickle

import pzflow
from pzflow import Flow
from pzflow.bijectors import Chain, ShiftBounds, RollingSplineCoupling
from pzflow.examples import get_galaxy_data

In [None]:
import healpy as hp
import pickle

import pandas as pd
from collections import OrderedDict

import sys
sys.path.insert(0, '/global/homes/q/qhang/desc/notebooks_for_analysis/')
import spatial_var_functions as svf
import measure_properties_with_systematics as mp

In [None]:
#import RailStage stuff
from rail.core.data import TableHandle
from rail.core.stage import RailStage
# RAIL's shared DataStore. Overwriting is enabled (on the DataStore class) so
# the same key, e.g. "test_data" in convert_catalog_to_test_data, can be
# re-added on every iteration of the estimation loops.
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

In [None]:
import matplotlib
# Colormap used to colour the depth bins in the final figure.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# the `matplotlib.colormaps` registry (available since 3.5) is the replacement.
cmap = matplotlib.colormaps['plasma']

### Pre-defined functions:

In [None]:
# Helper functions that assign observing conditions and magnitude errors to the galaxies and save the result; the pipeline is then looped over the depth bins.

def assign_pixels_to_gals(Ngal, pixels, random_seed=10):
    """Randomly assign each of ``Ngal`` galaxies to one of ``pixels``.

    Each galaxy draws a uniform number in [0, 1), which is binned into
    ``len(pixels)`` equal-width bins, so every pixel is equally likely.

    Parameters
    ----------
    Ngal : int
        Number of galaxies to place.
    pixels : array-like
        Candidate pixel indices (any sequence; converted with ``np.asarray``).
    random_seed : int
        Seed for NumPy's global random state.

    Returns
    -------
    numpy.ndarray
        Length-``Ngal`` array of pixel indices taken from ``pixels``.

    Raises
    ------
    ValueError
        If ``pixels`` is empty.
    """
    # Accept lists/tuples too: the fancy indexing at the end needs an ndarray.
    pixels = np.asarray(pixels)
    npix = len(pixels)
    if npix == 0:
        raise ValueError("assign_pixels_to_gals: `pixels` must not be empty")

    # Seeds NumPy's *global* random state (not a local generator) so that
    # results reproduce earlier runs exactly; this also affects any later
    # np.random.* calls in the session.
    np.random.seed(random_seed)
    print('Average no. of gal per pix: ', Ngal/npix)

    # Uniform draw in [0, 1) -> bin index in 0..npix-1.
    rand = np.random.uniform(size=Ngal)
    pixel_index = np.digitize(rand, np.linspace(0, 1, npix + 1)) - 1

    return pixels[pixel_index]


def assign_obs_cond_to_gals(assigned_pixels, obs_cond, mask, sys, bands):
    """Look up observing conditions at each galaxy's pixel.

    ``obs_cond[key][band]`` holds values only for the pixels where ``mask`` is
    non-zero. Each such array is scattered back onto a full-length map (zeros
    outside the mask) and then sampled at ``assigned_pixels``.

    Returns a nested dict ``{key: {band: per-galaxy array}}`` for every key in
    ``sys`` and every band in ``bands``.
    """
    in_footprint = mask.astype(bool)
    n_map = len(mask)

    def _sample_at_galaxies(masked_values):
        full_map = np.zeros(n_map)
        full_map[in_footprint] = masked_values
        return full_map[assigned_pixels]

    return {
        cond_key: {band: _sample_at_galaxies(obs_cond[cond_key][band]) for band in bands}
        for cond_key in sys
    }
    

    
def get_semi_major_minor(data, scale=1):
    """Return the (semi-major, semi-minor) axis arrays of each galaxy.

    The axis ratio is ``q = (1 - e) / (1 + e)`` with ``e = data['ellipticity']``.
    The major axis is ``data['size']`` and the minor axis is ``size * q``. Both
    are multiplied by ``scale``.

    Parameters
    ----------
    data : mapping of column name -> array-like
        Anything indexable by 'ellipticity' and 'size' (pandas DataFrame,
        dict of arrays, structured array).
    scale : float, optional
        Multiplicative rescaling applied to both axes.

    Returns
    -------
    ai, bi : numpy.ndarray
    """
    # np.asarray works for pandas Series as well as plain arrays/lists.
    ellipticity = np.asarray(data['ellipticity'])
    size = np.asarray(data['size'])
    q = (1 - ellipticity) / (1 + ellipticity)

    # Same operation order as before: (size * q) * scale.
    ai = size * scale
    bi = (size * q) * scale
    return ai, bi


def compute_mag_err(mag, ai, bi, assigned_obs_cond, obs_cond, band='u', m5key='coaddm5_per_visit'):
    """Per-galaxy magnitude error in ``band`` from the local observing conditions.

    The computation chains three ``svf`` helpers:
    1. ``get_area_ratio_auto`` from the galaxy axes (``ai``, ``bi``) and the seeing
       at each galaxy (``assigned_obs_cond['theta'][band]``);
    2. ``compute_sigma_rand_sq`` from the depth map ``assigned_obs_cond[m5key]``,
       ``mag``, ``obs_cond['gamma'][band]`` and the visit count
       ``assigned_obs_cond['nvis'][band]``;
    3. ``compute_magerr_lsstmodel`` with ``obs_cond['sigmasys']``.

    The result is multiplied by ``obs_cond['calibration']['magerrscale'][band]``.
    """
    area_ratio = svf.get_area_ratio_auto(ai, bi, assigned_obs_cond['theta'][band])

    depth = assigned_obs_cond[m5key][band]
    gamma = obs_cond['gamma'][band]
    n_visits = assigned_obs_cond['nvis'][band]
    random_term = svf.compute_sigma_rand_sq(depth, mag, gamma, n_visits, A_ratio=area_ratio)

    model_err = svf.compute_magerr_lsstmodel(random_term, obs_cond['sigmasys'], highSNR=True)

    error_scale = obs_cond['calibration']['magerrscale'][band]
    return model_err * error_scale


def assign_new_mag_magerr(data, magerr, ai, bi, assigned_obs_cond, obs_cond, bands, rng,
                          m5key='coaddm5_per_visit'):
    """Draw noisy observed magnitudes, apply magnitude limits and recompute errors.

    For each band:
    - the true magnitude is converted to flux and multiplied by
      ``1 + N(0, magerr[b])``, i.e. ``magerr[b]`` is used directly as the
      fractional flux noise, then converted back to a magnitude;
    - objects whose noisy magnitude is not brighter than
      ``obs_cond['sigLim'][b]`` get NaN in that band and are excluded from the
      returned selection. Non-positive noisy fluxes are clipped to 0 and so
      always fail the cut;
    - errors are recomputed with ``compute_mag_err`` at the noisy magnitudes.

    Parameters
    ----------
    data : pandas.DataFrame
        Catalogue with one true-magnitude column per band.
    magerr : dict
        Band -> per-galaxy magnitude error at the true magnitude.
    ai, bi : numpy.ndarray
        Semi-major / semi-minor axes (see ``get_semi_major_minor``).
    assigned_obs_cond, obs_cond : dict
        Per-galaxy and global observing conditions.
    bands : sequence of str
    rng : numpy.random.Generator
        Source of the flux noise.
    m5key : str, optional
        Depth key used when recomputing errors. The default keeps the value
        that used to be hard-coded here.

    Returns
    -------
    totObsIndex : numpy.ndarray of bool
        True for galaxies passing the limit in every band.
    ObsMags, ObsMagsErr : numpy.ndarray, shape (len(data), len(bands))
    """
    ngal = len(data)
    ObsMags = np.zeros((ngal, len(bands)))
    ObsMagsErr = np.zeros((ngal, len(bands)))
    # Start from "all kept" as a boolean array. This used to be the int 1,
    # which made the final `.astype(bool)` fail when `bands` was empty.
    totObsIndex = np.ones(ngal, dtype=bool)

    for ii, b in enumerate(bands):
        # true magnitudes -> fluxes, perturbed with Gaussian fractional noise
        fluxes = 10 ** (data[b].to_numpy() / -2.5)
        obsFluxes = fluxes * (1 + rng.normal(scale=magerr[b]))

        # negative fluxes -> 0 -> +inf magnitude, which then fails the cut
        with np.errstate(divide="ignore"):
            newmags = -2.5 * np.log10(np.clip(obsFluxes, 0, None))

        detected = newmags < obs_cond['sigLim'][b]
        totObsIndex &= detected

        newmags[~detected] = np.nan
        ObsMags[:, ii] = newmags

        # errors re-evaluated at the noisy (NaN where cut) magnitudes
        ObsMagsErr[:, ii] = compute_mag_err(ObsMags[:, ii], ai, bi, assigned_obs_cond,
                                            obs_cond, band=b, m5key=m5key)

    return totObsIndex, ObsMags, ObsMagsErr


def join_tables_and_save(data, ObsMags, ObsMagsErr, pixels, totObsIndex, outfile, bands=None):
    """Append observed magnitudes, errors and pixels to ``data``, cut, and save.

    Adds columns ``ObsMag_<band>``, ``ObsMagErr_<band>`` and ``pixels``. Keeps
    only the rows where ``totObsIndex`` is True and writes them with
    ``svf.dump_save`` to ``outfile``.

    Parameters
    ----------
    data : pandas.DataFrame
        Input catalogue; its index is reused for the new columns.
    ObsMags, ObsMagsErr : numpy.ndarray, shape (len(data), len(bands))
    pixels : array-like, length len(data)
    totObsIndex : boolean array, length len(data)
    outfile : str
    bands : sequence of str, optional
        Band names matching the columns of ``ObsMags``. When omitted, the
        notebook-level ``bands`` list is used. Earlier versions read that
        global implicitly, so this fallback keeps existing calls working.
    """
    if bands is None:
        bands = globals()['bands']

    magDf = pd.DataFrame(
        ObsMags, columns=[f"ObsMag_{band}" for band in bands], index=data.index
    )
    errDf = pd.DataFrame(
        ObsMagsErr, columns=[f"ObsMagErr_{band}" for band in bands], index=data.index
    )
    pixDf = pd.DataFrame(pixels, columns=['pixels'], index=data.index)

    # concat builds a new frame, so `data` itself is left untouched.
    obsCatalog = pd.concat([data, magDf, errDf, pixDf], axis=1)

    # finally select the rows that passed the magnitude limits
    savecat = obsCatalog.loc[totObsIndex, :]

    svf.dump_save(savecat, outfile)

    print(f"Saved: {outfile}.")
    
    
def obs_cond_pipeline(data, usepixels, random_seed, rng, obs_cond, mask, sys, bands, outfile,
                      m5key='coaddm5_per_visit'):
    """Simulate observed photometry for ``data`` on a set of pixels and save it.

    Steps:
    1. randomly place galaxies on ``usepixels``;
    2. look up the observing conditions named in ``sys`` at each galaxy's pixel;
    3. compute magnitude errors at the true magnitudes;
    4. draw noisy magnitudes, dropping objects beyond ``obs_cond['sigLim']``
       in any band;
    5. save the joined catalogue to ``outfile``.

    Parameters
    ----------
    data : pandas.DataFrame
        Input catalogue with one magnitude column per band plus 'size' and
        'ellipticity'.
    usepixels : array-like
        Pixels the galaxies may be assigned to.
    random_seed : int
        Seed for the pixel assignment and for the magnitude noise.
    rng : numpy.random.Generator
        NOTE(review): currently not used; see the comment at the noise step.
    obs_cond : dict
        Observing-condition maps (values on the masked pixels) and constants.
    mask : array
        Footprint mask; its non-zero entries are the pixels the maps cover.
    sys : list of str
        Keys of ``obs_cond`` to assign to each galaxy.
    bands : list of str
    outfile : str
    m5key : str, optional
        Depth key used for the initial magnitude errors (previously ignored).
    """
    # Use the catalogue actually passed in (previously read the global `mock_tract`).
    Ngal = len(data)

    assigned_pixels = assign_pixels_to_gals(Ngal, usepixels, random_seed=random_seed)
    assigned_obs_cond = assign_obs_cond_to_gals(assigned_pixels, obs_cond, mask, sys, bands)

    # galaxy axes, rescaled by the calibration factor
    scale = obs_cond['calibration']['abscale']
    ai, bi = get_semi_major_minor(data, scale=scale)

    # magnitude errors at the true magnitudes (honours `m5key`; was hard-coded)
    magerr = {
        b: compute_mag_err(data[b].to_numpy(), ai, bi, assigned_obs_cond,
                           obs_cond, band=b, m5key=m5key)
        for b in bands
    }

    # NOTE(review): as in the original code, the `rng` argument is overwritten
    # here. A fresh generator is seeded on every call, so every pixel set gets
    # the same noise draws regardless of call order. The seed now follows
    # `random_seed` instead of a hard-coded 10 (the notebook passes 10).
    # Confirm whether the caller's `rng` should be used instead.
    rng = np.random.default_rng(random_seed)

    # apply the noise, recompute errors, and cut objects beyond the limits
    totObsIndex, ObsMags, ObsMagsErr = assign_new_mag_magerr(data, magerr, ai, bi,
                                                             assigned_obs_cond,
                                                             obs_cond, bands, rng)

    join_tables_and_save(data, ObsMags, ObsMagsErr, assigned_pixels, totObsIndex, outfile)

In [None]:
# here load each set and pass on to FZBoost, extract zmode, 
# delete the output file (or re-write each time and delete later),
# save the zmode file
import tables_io

def convert_catalog_to_test_data(data, DS, bands):
    """Repackage an observed catalogue as RAIL test data and register it.

    For each band, the columns ``ObsMagErr_<b>`` and ``ObsMag_<b>`` become
    ``mag_err_<b>_lsst`` and ``mag_<b>_lsst``. ``redshift`` is copied as is.
    The resulting dict is converted with ``tables_io`` and stored in ``DS``
    under the key "test_data". The handle returned by ``DS.add_data`` is
    returned.
    """
    columns = OrderedDict()
    for band in bands:
        columns[f'mag_err_{band}_lsst'] = data[f'ObsMagErr_{band}'].to_numpy()
        columns[f'mag_{band}_lsst'] = data[f'ObsMag_{band}'].to_numpy()
    columns['redshift'] = data['redshift'].to_numpy()

    as_numpy_dict = tables_io.convert(columns, tables_io.types.NUMPY_DICT)
    return DS.add_data("test_data", as_numpy_dict, TableHandle)

In [None]:
def get_nz_meanz(pz, truez, pzbins, nbootstrap, zlim=(0, 2.0), bins=100, rng=None):
    """True-redshift distribution and its mean in each photo-z bin.

    For every bin ``[pzbins[i], pzbins[i+1])`` of photo-z ``pz``, the true
    redshifts are histogrammed on ``zlim`` with ``bins`` bins. The mean is taken
    over that histogram (bin centres weighted by counts), and its uncertainty is
    the standard deviation of the same statistic over ``nbootstrap`` bootstrap
    resamples.

    Parameters
    ----------
    pz : array-like
        Photo-z point estimates; flattened, so shape (N,) or (N, 1) both work.
    truez : array-like, shape (N,)
        True redshifts.
    pzbins : array-like
        Photo-z bin edges.
    nbootstrap : int
        Number of bootstrap resamples per bin.
    zlim : pair of float, optional
        Histogram range for the true redshift.
    bins : int, optional
        Number of histogram bins.
    rng : numpy.random.Generator, optional
        Generator for the bootstrap. If None, NumPy's global random state is
        used (the previous behaviour).

    Returns
    -------
    nztrue : dict
        Bin index -> array of shape (bins, 2) with columns (z centre, count).
    meanztrue, stdmeanz : numpy.ndarray, shape (len(pzbins) - 1,)
        Histogram mean true z and its bootstrap standard deviation. Empty
        bins give NaN.
    """
    choice = np.random.choice if rng is None else rng.choice
    pz = np.asarray(pz)
    truez = np.asarray(truez)
    nbin = len(pzbins) - 1

    nztrue = {}
    meanztrue = np.zeros(nbin)
    stdmeanz = np.zeros(nbin)

    for ii in range(nbin):
        ind = ((pz >= pzbins[ii]) & (pz < pzbins[ii + 1])).flatten()
        data = truez[ind]

        counts, edges = np.histogram(data, range=zlim, bins=bins)
        zz = (edges[1:] + edges[:-1]) * 0.5
        nztrue[ii] = np.c_[zz, counts]

        # mean true z from the histogram (consistent with the bootstrap below)
        meanztrue[ii] = np.sum(counts * zz) / np.sum(counts)

        # bootstrap the same histogram-mean statistic
        sampholder = np.zeros(nbootstrap)
        for kk in range(nbootstrap):
            samp = choice(data, size=len(data), replace=True)
            samp_counts = np.histogram(samp, range=zlim, bins=bins)[0]
            sampholder[kk] = np.sum(samp_counts * zz) / np.sum(samp_counts)
        stdmeanz[ii] = np.std(sampholder)

    return nztrue, meanztrue, stdmeanz

## Step 1: Load pre-trained flows

In [None]:
# Load the two pre-trained pzflow models. photFlow is sampled for redshifts and
# photometry; shapeFlow is sampled conditioned on that catalogue to add sizes
# and ellipticities (see the sampling cell below).
photFlow = Flow(file = "/export/donatello/qhang/nersc_local/main_galaxy_flow/flow.pzflow.pkl")
shapeFlow = Flow(file = "/export/donatello/qhang/nersc_local/conditional_galaxy_flow/flow.pzflow.pkl")

In [None]:
# Number of galaxies to draw from the flow. Roughly 30% of the sample survives
# the i-band limit at Y5, so draw 160000 / 0.3 objects to end up with ~160k.
Ngal_gen = int(160_000 / 0.3)
print(Ngal_gen)

In [None]:
# first sample the redshifts and photometry (fixed seed for reproducibility)
photoCat = photFlow.sample(Ngal_gen, seed=0)

# then add in the sizes and ellipticities, sampled conditionally on photoCat
fullCat = shapeFlow.sample(conditions=photoCat, seed=0)

In [None]:
# Apply per-band magnitude cuts (keep objects brighter than every limit), then save.
mag_cuts = {
    'i': 25,
    'u': 25.7,
    'g': 27.0,
    'r': 27.1,
    'z': 25.7,
    'y': 24.5,
}
sel = None
for cut_band, cut_limit in mag_cuts.items():
    passes_cut = fullCat[cut_band] < cut_limit
    sel = passes_cut if sel is None else (sel & passes_cut)

# number of objects surviving all cuts
print(int(sel.sum()))

# save
mock_tract = fullCat.loc[sel, :]
svf.dump_save(mock_tract, 'cosmoDC2_pzflow_sample_single_tract.pkl')

### Step 2: Load systematic maps

In [None]:
# Scratch directory holding the input maps, and the six LSST bands used throughout.
root = '/pscratch/sd/q/qhang/'
bands = [
    'u', 'g', 'r',
    'i', 'z', 'y',
]

In [None]:
def get_wfd_DESI_overlap(scratch):
    """Load the Rubin WFD and DESI footprint maps and find their overlap.

    Both HEALPix maps are read from ``<scratch>rubin_baseline_v2/``. The DESI
    completeness map is binarised in place: values <= 0 become 0 and values
    > 0 become 1.

    Returns
    -------
    wfd_mask, desi_mask : numpy.ndarray
        The WFD map as read and the binarised DESI map.
    overlap_pix : numpy.ndarray
        Indices of pixels where ``wfd_mask * desi_mask`` is non-zero.
    """
    map_dir = scratch + "rubin_baseline_v2/"
    wfd_mask = hp.read_map(map_dir + "wfd_footprint_nvisitcut_500_nside_128.fits")
    desi_mask = hp.read_map(map_dir + "DESI_footprint_completeness_mask_128.fits")

    # Both masks are taken before modifying the map; setting <=0 entries to 0
    # cannot create new >0 entries, so this matches sequential assignment.
    is_nonpositive = desi_mask <= 0
    is_positive = desi_mask > 0
    desi_mask[is_nonpositive] = 0
    desi_mask[is_positive] = 1

    overlap_pix = np.nonzero((wfd_mask * desi_mask).astype(bool))[0]
    return wfd_mask, desi_mask, overlap_pix

wfd_mask, desi_mask, overlap_pix = get_wfd_DESI_overlap(root)
# Product of the WFD map and the binarised DESI map.
overlap_mask = wfd_mask * desi_mask

In [None]:
# Load the Y1 observing conditions: baseline_v2 MAF HEALPix maps restricted to
# nights < 365, keeping only the pixels inside `mask`.
# opsim directory
Opsimdir = root + 'rubin_baseline_v2/MAF-1year/'

# Per-band maps to load: obs_cond key -> MAF metric name used in the filename
# (median seeing FWHM, coadded 5-sigma depth, number of visits).
metric_dict = {'theta':'Median_seeingFwhmEff',
               'coaddm5':'CoaddM5',
               'nvis':'Nvisits',}

mask = wfd_mask
in_footprint = mask.astype(bool)

obs_cond = {}

for key, name in metric_dict.items():
    print(f'Loading {key}...')
    obs_cond[key] = {}
    for b in bands:
        fname = Opsimdir+'baseline_v2_0_10yrs_%s_%s_and_nightlt365_HEAL.fits'%(name, b)
        fin1 = hp.read_map(fname)
        obs_cond[key][b] = fin1[in_footprint]

# Convert the coadded depth into a single-visit depth:
#   m5_visit = m5_coadd - 2.5*log10(sqrt(N_visits))
# NOTE(review): a pixel with zero visits would give delta_mag = -inf here.
obs_cond['coaddm5_per_visit'] = {}
for b in bands:
    delta_mag = 2.5*np.log10(np.sqrt(obs_cond['nvis'][b]/1.))
    obs_cond['coaddm5_per_visit'][b] = obs_cond['coaddm5'][b] - delta_mag

In [None]:
# Per-band gamma parameter passed to svf.compute_sigma_rand_sq.
obs_cond['gamma'] = {
    'u':0.038,
    'g':0.039,
    'r':0.039,
    'i':0.039,
    'z':0.039,
    'y':0.039,
}
# Systematic error term passed to svf.compute_magerr_lsstmodel.
obs_cond['sigmasys'] = 0

# add calibration
obs_cond['calibration'] = {}
# Scale applied to galaxy sizes in get_semi_major_minor. A stray trailing comma
# here used to make this the tuple (0.4,) instead of the float 0.4.
obs_cond['calibration']['abscale'] = 1/2.5
# Per-band multiplicative factor applied to the magnitude errors.
obs_cond['calibration']['magerrscale']={
    'u': 0.73,
    'g': 1.20,
    'r': 0.98, 
    'i': 1.10,
    'z': 1.10,
    'y': 1.15,
}

# Per-band magnitude limits: objects not brighter than these are dropped.
obs_cond['sigLim'] = {
    'u': 24.9, 
    'g': 26.2, 
    'r': 26.3, 
    'i': 24.1, 
    'z': 24.9, 
    'y': 23.7,
}

In [None]:
# Split the footprint into `nquantiles` sets of pixels by i-band coadd depth.
# Note: despite the names, `qtl` is linearly spaced in depth between
# ranges[0] and ranges[1] (presumably used as bin edges by
# select_pixels_from_sysmap); these are not quantiles of the depth distribution.
ranges = [24.6, 25.7]
nquantiles = 20
sysmap = np.zeros(len(mask))
sysmap[mask.astype(bool)] = obs_cond['coaddm5']['i']

# define the bins
qtl = np.linspace(ranges[0], ranges[1], nquantiles + 1)
added_range = False
selected_pix = mp.select_pixels_from_sysmap(sysmap, mask, qtl, added_range=added_range)
print(len(selected_pix))

# mean and median i-band depth of the pixels in each bin
mean_sys = np.array([np.mean(sysmap[selected_pix[ii]]) for ii in range(nquantiles)])
median_sys = np.array([np.median(sysmap[selected_pix[ii]]) for ii in range(nquantiles)])

plt.plot(mean_sys, label='mean_sys')
plt.plot(median_sys, label='median_sys')
plt.legend()
plt.xlabel('bin number')
plt.ylabel('i-band coadd depth (Y1)')

In [None]:
# Load the PZflow test sample: the single-tract catalogue saved in Step 1.
# (The former mid-cell `mock_tract.head()` was a no-op: its result is only
# displayed when it is the last expression of a cell.)
mock_tract = svf.dump_load('cosmoDC2_pzflow_sample_single_tract.pkl')

print("Number of obj in this sample: ", len(mock_tract))

In [None]:
# Assign observing conditions for each depth bin and save one catalogue per bin.
random_seed = 10
rng = np.random.default_rng(10)
# Named `sys_keys` (not `sys`) so the notebook's imported `sys` module is not shadowed.
sys_keys = ['coaddm5_per_visit', 'nvis', 'theta']

data = mock_tract
outdir = '/pscratch/sd/q/qhang/PZflow-samples/baselinev2.0-test/y1/'

for q in range(len(selected_pix)):
    usepixels = selected_pix[q]
    outfile = outdir + 'cosmoDC2_pzflow_sample_obs_cal-coaddm5-i-qtl-%s.pkl' % q

    obs_cond_pipeline(data, usepixels, random_seed, rng, obs_cond, mask, sys_keys, bands, outfile,
                      m5key='coaddm5_per_visit')

### Step 3: run photo-z estimation

#### Here is an example of FZboost

In [None]:
from rail.estimation.algos.flexzboost import Inform_FZBoost, FZBoost

# Pre-trained FlexZBoost model file. Inform_FZBoost is imported above but is
# not run in this notebook, so this file is assumed to already exist.
fz_modelfile = 'FZB_test.pkl'
pzflex = FZBoost.make_stage(name='fzboost', hdf5_groupname='',
                            model=fz_modelfile)


In [None]:
# Run FlexZBoost on every depth-bin catalogue and save the per-object PDF modes.
# Processes all bins, including q=0, which a leftover debugging range(1, ...) skipped.
# NOTE(review): these paths point at the older DC2-test / minion_1016
# catalogues, not the baselinev2.0 ones written in Step 2; confirm which is intended.
# A local directory name is used so the global `root` (used when loading the
# maps) is not overwritten.
fz_root = '/pscratch/sd/q/qhang/PZflow-samples/DC2-test/'
zgrid = np.linspace(0, 3., 301)  # grid on which the PDF mode is evaluated

for q in range(len(selected_pix)):
    print("Working on qtl %d" % q)

    fname = fz_root + 'cosmoDC2_pzflow_sample_minion_1016_y1_obs_cal-coaddm5-i-qtl-%s.pkl' % q
    data = svf.dump_load(fname)

    test_data = convert_catalog_to_test_data(data, DS, bands)

    fzresults = pzflex.estimate(test_data)

    # obtain the mode:
    fz_modes = fzresults().mode(grid=zgrid)

    # save:
    fname = fz_root + 'cosmoDC2_pzflow_sample_minion_1016_y1_obs_cal-coaddm5-i-qtl-%s-zmode.pkl' % q
    svf.dump_save(fz_modes, fname)

#### Here is an example of BPZ

In [None]:
from rail.estimation.algos.bpz_lite import BPZ_lite

# BPZ magnitude / error column names, matching those produced by
# convert_catalog_to_test_data.
_bpz_bands = ['u', 'g', 'r', 'i', 'z', 'y']
band_names = ['mag_%s_lsst' % b for b in _bpz_bands]
band_err_names = ['mag_err_%s_lsst' % b for b in _bpz_bands]
prior_band = 'mag_i_lsst'

output = "newBPZ_test.hdf5"

# mag_limits change to Y1 limits:
# NOTE(review): confirm these values are the intended Y1 limits.
bpz_mag_limits = {
    'mag_u_lsst': 27.79,
    'mag_g_lsst': 29.04,
    'mag_r_lsst': 29.06,
    'mag_i_lsst': 28.62,
    'mag_z_lsst': 27.98,
    'mag_y_lsst': 27.05,
}

# Options tried earlier and not passed here:
#   columns_file=inroot+'test_bpz.columns', prior_file='CWW_HDFN_prior.pkl',
#   spectra_file='SED/CWWSB4.list'
estimate_bpz = BPZ_lite.make_stage(
    name='estimate_bpz',
    hdf5_groupname='',
    nondetect_val=np.nan,
    band_names=band_names,
    band_err_names=band_err_names,
    prior_band=prior_band,
    mag_limits=bpz_mag_limits,
    output=output,
)

In [None]:
# Run BPZ on every depth-bin catalogue and save the per-object zmode.
# A local directory name is used so the global `root` (used when loading the
# maps) is not overwritten.
bpz_root = '/pscratch/sd/q/qhang/PZflow-samples/baselinev2.0-test/y1/'

for q in range(len(selected_pix)):
    print("Working on qtl %d" % q)

    fname = bpz_root + 'cosmoDC2_pzflow_sample_obs_cal-coaddm5-i-qtl-%s.pkl' % q
    data = svf.dump_load(fname)

    test_data = convert_catalog_to_test_data(data, DS, bands)

    bpz_estimated = estimate_bpz.estimate(test_data)

    # obtain the mode:
    zmode = bpz_estimated().ancil['zmode']

    # save:
    fname = bpz_root + 'bpz/cosmoDC2_pzflow_sample_obs_cal-coaddm5-i-qtl-%s-zmode.pkl' % q
    svf.dump_save(zmode, fname)

### Step 4: Check the shift in mean redshifts

In [None]:
# For each depth bin, load the BPZ modes and true redshifts. Then compute the
# true n(z), its mean, and a bootstrap error on the mean in each BPZ bin.
pzbins = np.linspace(0.2, 1.2, 5 + 1)  # BPZ bin edges (also used by the plots)
nbootstrap = 1000
nz_root = '/pscratch/sd/q/qhang/PZflow-samples/baselinev2.0-test/y1/'

NZTRUE = {}
MEANZTRUE = {}
STDMEANZ = {}

# Loop over the same depth bins that were written above (was a hard-coded 20).
for q in range(len(selected_pix)):
    fname = nz_root + 'bpz/cosmoDC2_pzflow_sample_obs_cal-coaddm5-i-qtl-%s-zmode.pkl' % q
    pz = svf.dump_load(fname)

    fname = nz_root + 'cosmoDC2_pzflow_sample_obs_cal-coaddm5-i-qtl-%s.pkl' % q
    cat = svf.dump_load(fname)
    truez = cat['redshift'].to_numpy()

    NZTRUE[q], MEANZTRUE[q], STDMEANZ[q] = get_nz_meanz(pz, truez, pzbins, nbootstrap)

In [None]:
# Footprint-averaged n(z) per BPZ bin: sum the per-depth-bin n(z), weighting
# each depth bin by its number of pixels, then take the mean redshift.
nzbins = len(pzbins) - 1
# Hoisted: Python's builtin sum() over the full HEALPix mask used to be
# re-evaluated inside the nested loop.
mask_total = np.sum(mask)

mean = np.zeros(nzbins)
for ii in range(nzbins):
    dist = 0
    for q in NZTRUE:
        w = len(selected_pix[q]) / mask_total
        dist += NZTRUE[q][ii][:, 1] * w
    # all depth bins share the same z grid, so any q's first column works
    mean[ii] = np.sum(dist * NZTRUE[q][ii][:, 0]) / np.sum(dist)
print(mean)

In [None]:
# Number of depth bins and BPZ bins actually loaded above.
nquantiles = len(NZTRUE)
nzbins = len(pzbins) - 1
depth_lo, depth_hi = ranges  # i-band depth range used to define the depth bins

fig, axarr = plt.subplots(2, nzbins, figsize=[15, 5], gridspec_kw={'height_ratios': [3, 1]})
for ii in range(nzbins):
    # Top row: normalised true n(z) of this BPZ bin, one curve per depth bin.
    plt.sca(axarr[0, ii])
    for q in range(nquantiles):
        colorlab = q/(nquantiles*1.2)
        zz = NZTRUE[q][ii][:, 0]
        counts = NZTRUE[q][ii][:, 1]
        dzz = zz[1] - zz[0]
        plt.plot(zz, counts/np.sum(counts)/dzz, color=cmap(colorlab))
    plt.ylim([0, 4.1])
    plt.plot([mean[ii], mean[ii]], [0, 4.1], 'k')
    plt.text(0.6, 3.5, "%.2f < z(BPZ) < %.2f" % (pzbins[ii], pzbins[ii+1]))
    plt.yticks([])
    plt.xlabel("z (truth)")

    # Bottom row: shift of each depth bin's mean true z from the footprint
    # average, with a shaded +-0.005(1+z) band.
    plt.sca(axarr[1, ii])
    for q in range(nquantiles):
        colorlab = q/(nquantiles*1.2)
        plt.errorbar(mean_sys[q], MEANZTRUE[q][ii]-mean[ii], yerr=STDMEANZ[q][ii], fmt='o',
                     color=cmap(colorlab))
    dz = 0.005*(1+mean[ii])

    plt.plot([depth_lo, depth_hi], [0, 0], 'k-', alpha=0.5)
    plt.fill_between([depth_lo, depth_hi], [-dz, -dz], [dz, dz], color='k', alpha=0.2)
    if ii == 0:
        plt.ylabel("$z - \\langle z\\rangle$")
    else:
        plt.yticks([])
    plt.xlabel("$i$ (Coadd, 1-year)")
    plt.ylim([-0.015, 0.015])
    plt.xlim([depth_lo, depth_hi])

plt.tight_layout()
# was `plt.saveifg` (typo), which raised AttributeError and never saved the figure
plt.savefig('fig.png', bbox_inches='tight')