# TPz lite

_Authors: Andreia Dourado, Bruno Moraes_

_Adapted from Sam Schmidt example notebook: https://github.com/LSSTDESC/rail_tpz ._


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rail
import qp
from rail.core.data import TableHandle, PqHandle, ModelHandle, QPHandle, DataHandle, Hdf5Handle
from rail.core.data import TableHandle
from rail.core.stage import RailStage

### 1. Reading the data

In [None]:
data = pd.read_csv("training_set.csv")

In [None]:
data

### 2. Analysing the training set

#### 2.1 Functions

In [None]:
def plot_errors(catalog):
    """Plot signal-to-noise versus magnitude for each LSST band.

    Draws a 3x2 grid of log-scaled hexbin panels, one per band
    (u, g, r, i, z, y), with a red horizontal line at S/N = 10, and saves
    the figure to ``errors_training_set.png`` in the working directory.

    Parameters
    ----------
    catalog : mapping of column name to array-like
        Must provide ``mag_{band}_lsst`` and ``mag_err_{band}_lsst`` for
        every band listed above.
    """
    bands = ['u', 'g', 'r', 'i', 'z', 'y']
    plt.figure(figsize=[9, 13])
    for panel, band in enumerate(bands, start=1):
        plt.subplot(3, 2, panel)
        mag = np.array(catalog[f'mag_{band}_lsst'])
        err = np.array(catalog[f'mag_err_{band}_lsst'])
        # Invert err = 2.5 * log10(1 + 1/SN) to recover the signal-to-noise.
        sn = 1 / (10 ** (0.4 * err) - 1)
        plt.hexbin(mag, sn, None, mincnt=1, cmap='inferno',
                   gridsize=[200, 100], bins='log')
        plt.colorbar()
        plt.ylabel("S/R ", fontsize=13)
        plt.xlabel("mag " + band, fontsize=13)
        plt.ylim(0, 100)
        plt.xlim(20, 30)
        plt.axhline(10, color='red')
        plt.grid(True)
    # Lay out and write the figure once, after every panel has been drawn,
    # instead of re-saving the same file on each loop iteration.
    plt.tight_layout()
    plt.savefig('errors_training_set.png', format='png')

In [None]:
def mag_histogram(catalog, title='DP0.2'):
    """Plot the magnitude distribution of each LSST band.

    One log-scaled histogram per band is drawn on a 3x2 grid using
    0.5-mag wide bins between 9 and 37. The figure is saved to
    ``mag_histogram_training_set.png`` and then shown.

    Parameters
    ----------
    catalog : mapping of column name to array-like
        Must provide ``mag_{band}_lsst`` for u, g, r, i, z and y.
    title : str
        Figure-level title.
    """
    band_colors = {
        'u': 'blue',
        'g': 'green',
        'r': 'orange',
        'i': 'red',
        'z': 'purple',
        'y': 'gray',
    }
    plt.figure(figsize=(9, 13))
    edges = np.linspace(9, 37, 57)
    for panel, (band, color) in enumerate(band_colors.items(), start=1):
        plt.subplot(3, 2, panel)
        plt.hist(
            catalog[f'mag_{band}_lsst'],
            histtype='stepfilled',
            bins=edges,
            label=f'{band} band',
            alpha=0.5,
            edgecolor="black",
            color=color,
        )
        plt.yscale('log')
        plt.xlabel('mag', fontsize=13)
        plt.ylabel('counts', fontsize=13)
        plt.legend(loc=2)
        plt.grid(True)
    plt.suptitle(title)
    plt.savefig('mag_histogram_training_set.png', format='png')
    plt.show()

In [None]:
def redshift_hist(catalog):
    """Histogram the ``redshift`` column between 0 and 3.

    The figure is saved to ``redshift_training_set.png`` and then shown.
    """
    z_edges = np.linspace(0, 3, 200)
    plt.hist(catalog['redshift'], bins=z_edges)
    plt.savefig('redshift_training_set.png', format='png')
    plt.show()

In [None]:
def mag_color(catalog):
    """Plot magnitude versus colour for each pair of adjacent LSST bands.

    For every consecutive pair (u-g, g-r, r-i, i-z, z-y) a log-scaled hexbin
    of the bluer band's magnitude against the colour is drawn on a 3x2 grid.
    The figure is saved to ``magColor_training_set.png`` and then shown.

    Parameters
    ----------
    catalog : mapping of column name to array-like
        Must provide ``mag_{band}_lsst`` for u, g, r, i, z and y, with
        columns that support element-wise subtraction.
    """
    bands = ['u', 'g', 'r', 'i', 'z', 'y']
    plt.figure(figsize=(9, 13))
    for panel, (band, next_band) in enumerate(zip(bands, bands[1:]), start=1):
        plt.subplot(3, 2, panel)
        mag = catalog[f'mag_{band}_lsst']
        colour = mag - catalog[f'mag_{next_band}_lsst']
        plt.hexbin(mag, colour, None, mincnt=1, cmap='Reds',
                   gridsize=[400, 200], bins='log')
        plt.xlabel("mag " + band, fontsize=13)
        plt.ylabel(f"{band}-{next_band}", fontsize=13)
        plt.xlim(16, 32)
        plt.ylim(-2, 5)
        plt.grid(True)
    # Lay out and write the figure once, after all panels exist, rather than
    # re-saving the same file on every iteration.
    plt.tight_layout()
    plt.savefig('magColor_training_set.png', format='png')
    plt.show()

In [None]:
def color_color(catalog):
    """Plot colour-colour diagrams for consecutive LSST band triplets.

    For each triplet (b0, b1, b2) of adjacent bands, the x axis shows
    ``b0 - b1`` and the y axis shows ``b1 - b2`` as a log-scaled hexbin.
    The figure is saved to ``colorColor_training_set.png`` and then shown.

    Parameters
    ----------
    catalog : mapping of column name to array-like
        Must provide ``mag_{band}_lsst`` for u, g, r, i, z and y, with
        columns that support element-wise subtraction.
    """
    bands = ['u', 'g', 'r', 'i', 'z', 'y']
    plt.figure(figsize=(12, 12))
    triplets = zip(bands, bands[1:], bands[2:])
    for panel, (blue, mid, red) in enumerate(triplets, start=1):
        plt.subplot(3, 2, panel)
        blue_mag = catalog[f'mag_{blue}_lsst']
        mid_mag = catalog[f'mag_{mid}_lsst']
        red_mag = catalog[f'mag_{red}_lsst']
        plt.hexbin(blue_mag - mid_mag, mid_mag - red_mag, None, mincnt=1,
                   cmap='turbo', gridsize=[400, 200], bins='log')
        # Labels follow the plotted quantities: x is the bluer colour,
        # y is the redder colour.
        plt.xlabel(f'{blue}-{mid}', fontsize=13)
        plt.ylabel(f'{mid}-{red}', fontsize=13)
        plt.colorbar()
        plt.xlim(-1, 3)
        plt.ylim(-1, 3)
    # Write the file once, after all panels have been drawn.
    plt.savefig('colorColor_training_set.png', format='png')
    plt.show()

In [None]:
def color_color_red(catalog,xlim=[-1,3],ylim=[-1,3]):
    """Plot colour-colour diagrams coloured by redshift.

    For each triplet (b0, b1, b2) of adjacent bands, the x axis shows
    ``b0 - b1`` and the y axis shows ``b1 - b2``. Each hexagon is coloured
    from the ``redshift`` values of the objects it contains, using hexbin's
    default reduction. The figure is saved to
    ``colorColorRed_training_set.png`` and then shown.

    Parameters
    ----------
    catalog : mapping of column name to array-like
        Must provide ``redshift`` and ``mag_{band}_lsst`` for u, g, r, i, z
        and y, with columns that support element-wise subtraction.
    xlim, ylim : sequence of two numbers
        Lower and upper axis limits applied to every panel. The default
        lists are only read, never modified.
    """
    bands = ['u', 'g', 'r', 'i', 'z', 'y']
    plt.figure(figsize=(12, 12))
    triplets = zip(bands, bands[1:], bands[2:])
    for panel, (blue, mid, red) in enumerate(triplets, start=1):
        plt.subplot(3, 2, panel)
        blue_mag = catalog[f'mag_{blue}_lsst']
        mid_mag = catalog[f'mag_{mid}_lsst']
        red_mag = catalog[f'mag_{red}_lsst']
        plt.hexbin(blue_mag - mid_mag, mid_mag - red_mag,
                   C=catalog['redshift'], mincnt=1, cmap='turbo',
                   gridsize=[400, 200])
        # Labels follow the plotted quantities: x is the bluer colour,
        # y is the redder colour.
        plt.xlabel(f'{blue}-{mid}', fontsize=13)
        plt.ylabel(f'{mid}-{red}', fontsize=13)
        plt.colorbar(label='redshift')
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
    # Write the file once, after all panels have been drawn.
    plt.savefig('colorColorRed_training_set.png', format='png')
    plt.show()

In [None]:
def spatial_distribution(catalog):
    """Scatter the catalogue on the sky (RA vs Dec), coloured by redshift.

    Reads the ``ra``, ``dec`` and ``redshift`` columns and saves the figure
    to ``spatial_distribution.png``.
    """
    ra = catalog['ra']
    dec = catalog['dec']
    redshift = catalog['redshift']
    plt.scatter(ra, dec, c=redshift, cmap='turbo', s=1)
    plt.xlabel('RA [deg]')
    plt.ylabel('Dec [deg]')
    plt.colorbar()
    plt.savefig('spatial_distribution.png', format='png')

#### 2.2 Plots

Spatial distribution

In [None]:
spatial_distribution(data)

Redshift distribution

In [None]:
redshift_hist(data)

Errors

In [None]:
plot_errors(data)

Magnitude distribution

In [None]:
mag_histogram(data)

Mag-color

In [None]:
mag_color(data)

Color-color

In [None]:
color_color(data)

In [None]:
color_color_red(data)

### 3. Applying cuts

#### 3.1 i < 25.3, according to LSST gold sample (https://www.lsst.org/sites/default/files/docs/sciencebook/SB_3.pdf)

In [None]:
data = data[data["mag_i_lsst"] < 25.3]
data

#### 3.2 i > 16, according to detection limit

In [None]:
data = data[data["mag_i_lsst"] >  16]
data

#### 3.3 Removing NaN values

In [None]:
data = data.dropna()
data

### 4. Run TPz

In [None]:
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tables_io

In [None]:
from rail.estimation.algos.tpz_lite import TPZliteInformer

In [None]:
from rail.core.utils import RAILDIR

In [None]:
from rail.core.utils import find_rail_file

In [None]:
import h5py

Train: 50000 objects

In [None]:
training_csv = data.sample(50000,random_state=40)

To write the HDF5 file correctly, the dataframe is first converted to a dictionary with the parameter _"orient='list'"_, which produces a dictionary of the form _{column -> [values]}_.

Creating the dictionary:

In [None]:
dict_list = training_csv.to_dict(orient='list')


Creating the hdf5 file:

In [None]:
RAILDIR #this is the path to save the hdf5 files

In [None]:
hdf5_file_path = '/home/andreia.dourado/.conda/envs/tpz_lite/lib/python3.11/site-packages/train_data.hdf5'

Writing the dictionary values in the hdf5 file:

In [None]:
# The training columns are written under a 'photometry' group so that the
# file layout matches hdf5_groupname='photometry' in the informer
# configuration and the layout used for the test file.
with h5py.File(hdf5_file_path, 'w') as hdf5_file:
    def write_dict(group, data):
        """Recursively write a (possibly nested) dict into an HDF5 group.

        Nested dicts become sub-groups; any other value is stored as a
        dataset named after its key.
        """
        for key, value in data.items():
            if isinstance(value, dict):
                write_dict(group.create_group(str(key)), value)
            else:
                group[str(key)] = value

    write_dict(hdf5_file.create_group('photometry'), dict_list)

In [None]:
# NOTE(review): the second argument is an absolute path, so os.path.join
# returns it unchanged and RAILDIR is ignored. The result is the same path as
# the train_data.hdf5 file written above.
datafile = os.path.join(RAILDIR,'/home/andreia.dourado/.conda/envs/tpz_lite/lib/python3.11/site-packages/train_data.hdf5')

Creating the training_data file:

In [None]:
training_data = DS.read_file("training_data", TableHandle, datafile)

In [None]:
print(training_data.data)

#### 4.1 Setting the parameters

In [None]:
bands = ["u", "g", "r", "i", "z", "y"]

# Magnitude columns used as training attributes, one per band.
attribute_list = [f"mag_{band}_lsst" for band in bands]

# Map each magnitude column to the column holding its error.
new_err_dict = {
    f"mag_{band}_lsst": f"mag_err_{band}_lsst" for band in bands
}
# redshift is also an attribute used in the training, but it has no
# associated error, so its entry in the error dict is None.
new_err_dict["redshift"] = None

print(new_err_dict)

In [None]:
# Keyword arguments forwarded to TPZliteInformer.make_stage below.
# zmin/zmax span the redshift range of the sampled training set; use_atts and
# err_dict are the attribute list and error mapping built in the previous cell.
# NOTE(review): hdf5_groupname='photometry' presumably requires the training
# file to hold its columns under a 'photometry' group -- confirm against the
# cell that writes train_data.hdf5.
tpz_dict = dict(zmin=min(training_csv['redshift']), 
                zmax=max(training_csv['redshift']), 
                nzbins=301, 
                hdf5_groupname='photometry',
                use_atts=attribute_list,
                err_dict=new_err_dict,
                nrandom=3, 
                ntrees=8)

#### 4.2 Inform method

In [None]:
pz_train = TPZliteInformer.make_stage(name='inform_TPZ', model='demo_tpz.pkl', **tpz_dict)

In [None]:
%%time
pz_train.inform(training_data)

#### 4.3 Estimate stage

Selecting the data:

In [None]:
validation= data.drop(training_csv.index)

Writing the HDF5 file, similarly to the training set:

In [None]:
dict_list = validation.to_dict(orient='list')

In [None]:
hdf5_file_path = ('/home/andreia.dourado/.conda/envs/tpz_lite/lib/python3.11/site-packages/test_data.hdf5')

In [None]:
def write_dict(group, data):
    """Store every entry of ``data`` as a dataset in the HDF5 ``group``."""
    for column in data:
        group[str(column)] = data[column]


# All validation columns go under a single 'photometry' group.
with h5py.File(hdf5_file_path, 'w') as hdf5_file:
    photometry_group = hdf5_file.create_group('photometry')
    write_dict(photometry_group, dict_list)

Creating the test_data file:

In [None]:
# NOTE(review): the second argument is an absolute path, so os.path.join
# returns it unchanged and RAILDIR is ignored. The result is the same path as
# the test_data.hdf5 file written above.
testfile = os.path.join(RAILDIR,'/home/andreia.dourado/.conda/envs/tpz_lite/lib/python3.11/site-packages/test_data.hdf5')

In [None]:
test_data = DS.read_file("test_data", TableHandle, testfile)

Run:

In [None]:
from rail.estimation.algos.tpz_lite import TPZliteEstimator

In [None]:
test_dict = dict(hdf5_groupname='photometry')

In [None]:
test_runner = TPZliteEstimator.make_stage(name="test_tpz", output="TPZ_demo_output.hdf5",
                                          model=pz_train.get_handle('model'), **test_dict)

In [None]:
%%time
results = test_runner.estimate(test_data)

### 5. Plotting point estimates and an example PDF

In [None]:
sz = test_data()['photometry']['redshift']
zmode = results().ancil['zmode']

In [None]:
plt.figure(figsize=(8,8))
plt.scatter(sz,zmode, s=2,c='k')
plt.plot([0,3],[0,3],'r--')
plt.xlabel("redshift", fontsize=15)
plt.ylabel("TPZ mode", fontsize=15)

In [None]:
which=1355
fig, axs = plt.subplots()
results().plot_native(key=which,axes=axs, label=f"PDF for galaxy {which}")
axs.axvline(sz[which],c='r',ls='--', label="true redshift")
plt.legend(loc='upper right', fontsize=12)
axs.set_xlabel("redshift")

### 6. Metrics

In [None]:
def _binned_photoz_metrics(zspec, zphot, bins):
    """Compute per-bin photo-z metrics, binning objects by ``zphot``.

    Returns four lists with one entry per bin (``len(bins) - 1`` entries):
    mean bias, sigma_68, 2-sigma outlier fraction and 3-sigma outlier
    fraction. Bins that contain no object yield NaN for every metric.
    """
    meanz, sigma68z, outliers_2, outliers_3 = [], [], [], []
    for bin_lower, bin_upper in zip(bins[:-1], bins[1:]):
        in_bin = (zphot >= bin_lower) & (zphot <= bin_upper)
        values_s = zspec[in_bin]
        deltabias = zphot[in_bin] - values_s

        if deltabias.size == 0:
            # Nothing to measure here; NaN leaves a gap in the plotted line
            # instead of raising on the empty arrays.
            for series in (meanz, sigma68z, outliers_2, outliers_3):
                series.append(np.nan)
            continue

        # Mean bias of the bin.
        mean_bias = np.mean(deltabias)
        meanz.append(mean_bias)

        # sigma_68: the 68th percentile of |dz| / (1 + z_spec).
        scaled = np.sort(np.abs(deltabias / (1 + values_s)))
        sigma68z.append(scaled[int(len(scaled) * 0.68)])

        # Standard deviation of dz around the mean bias; used only as the
        # yardstick for the outlier fractions.
        sigma = (np.sum((deltabias - mean_bias) ** 2) / len(deltabias)) ** 0.5
        deviation = np.abs(deltabias - mean_bias)
        outliers_2.append(np.count_nonzero(deviation > 2 * sigma) / len(deltabias))
        outliers_3.append(np.count_nonzero(deviation > 3 * sigma) / len(deltabias))

    return meanz, sigma68z, outliers_2, outliers_3


def plot_metrics(zspec,
                 zphot,
                 maximum,
                 path_to_save='',
                 title=None,
                 initial=0):
    """Plot bias, sigma_68 and 2-/3-sigma outlier fractions per redshift bin.

    Objects are grouped in bins of width 0.1 in *photometric* redshift
    between ``initial`` and ``maximum``, and each metric is drawn in its own
    panel of a 4x1 figure sharing the x axis. Empty bins appear as gaps.

    Parameters
    ----------
    zspec : array-like
        Spectroscopic (true) redshifts.
    zphot : array-like
        Photometric redshift point estimates, same length as ``zspec``.
    maximum : float
        Upper end of the redshift range that is binned and plotted.
    path_to_save : str, optional
        If not empty, the figure is saved to this path.
    title : str, optional
        Figure-level title.
    initial : float, optional
        Lower end of the redshift range that is binned and plotted.
    """
    zspec = np.asarray(zspec)
    zphot = np.asarray(zphot)

    bins = np.arange(initial, maximum, 0.1)
    points = bins + 0.05  # bin centres; the last one has no matching bin
    meanz, sigma68z, outliers_2, fraction_outliers = _binned_photoz_metrics(
        zspec, zphot, bins)

    fig, axes = plt.subplots(4, 1, figsize=(8, 14), sharex=True)
    plt.subplots_adjust(hspace=0.001)

    # The x range follows the requested binning range; the axes share x, so
    # no further hard-coded limit is applied afterwards.
    x_lim = (initial, np.max(bins))
    for ax in axes:
        ax.set_xlim(x_lim)
        ax.tick_params(axis='both', labelsize=14)
        ax.grid(True)

    # Subplot 1: mean bias
    axes[0].plot(points[:-1], meanz, 'bo-')
    axes[0].axhline(0, color='black', linestyle='--')
    axes[0].set_ylabel(r'$\Delta z$', fontsize=20)
    axes[0].set_ylim(-0.05, 0.05)

    # Subplot 2: sigma 68
    axes[1].plot(points[:-1], sigma68z, 'go-')
    axes[1].set_ylabel(r'$\sigma_{68}$', fontsize=20)
    axes[1].axhline(0.12, color='black', linestyle='--')
    # NOTE(review): an earlier data-driven set_ylim here was immediately
    # overridden by this call and has been dropped; confirm that these fixed
    # limits and ticks are the intended ones.
    axes[1].set_ylim(0, 0.03)
    axes[1].set_yticks(np.arange(0, 0.5, 0.05))

    # Subplot 3: 2-sigma outliers
    axes[2].plot(points[:-1], outliers_2, 'o-', color='darkorange')
    axes[2].set_ylabel('out$_{2σ}$', fontsize=20)
    axes[2].set_ylim(0, 0.12)
    axes[2].axhline(0.1, color='black', linestyle='--')

    # Subplot 4: 3-sigma outliers
    axes[3].plot(points[:-1], fraction_outliers, 'ro-')
    # The bins are defined in photometric redshift, so label the axis so.
    axes[3].set_xlabel(r'$Z_{phot}$', fontsize=20)
    axes[3].set_ylabel('out$_{3σ}$', fontsize=20)
    axes[3].set_ylim(0, 0.02)
    axes[3].axhline(0.015, color='black', linestyle='--')

    plt.suptitle(title)
    plt.tight_layout()

    if path_to_save != '':
        plt.savefig(f'{path_to_save}')

    plt.show()


In [None]:
zmode= np.squeeze(zmode)
plot_metrics(sz,zmode,max(zmode)-0.2,initial=0,title='metrics_cortes_melissa.png')

#### 6.1 PIT-QQ

In [None]:
from qp.ensemble import Ensemble
from matplotlib import gridspec
from qp import interp
from qp.metrics.pit import PIT

In [None]:
def plot_pit_qq(pdfs, zgrid, ztrue, bins=None, title=None, code=None,
                show_pit=True, show_qq=True,
                pit_out_rate=None, savefig=False) -> str:
    """Draw a quantile-quantile plot and/or PIT histogram for a set of PDFs.

    Parameters
    ----------
    pdfs: `ndarray`
        photo-z PDFs evaluated on `zgrid`, one row per object
    zgrid: `ndarray`
        redshift grid on which the PDFs are tabulated
    ztrue: `ndarray`
        true redshifts, one per object
    bins: `int` or sequence, optional
        number of PIT histogram bins, or the bin edges (default=100)
    title: `str`, optional
        plot title (default: empty)
    code: `str`, optional
        algorithm name, used in the legend and in the output file name
    show_pit: `bool`, optional
        include PIT histogram (default=True)
    show_qq: `bool`, optional
        include QQ plot (default=True)
    pit_out_rate: number, optional
        metric value to print in the legend (default=None)
    savefig: `bool`, optional
        save plot in .png file (default=False)

    Returns
    -------
    `str` or None
        name of the saved file, or None when `savefig` is False
    """

    if bins is None:
        bins = 100
    if title is None:
        title = ""

    if code is None:
        code = ""
        label = ""
    else:
        label = code + "\n"

    if pit_out_rate is not None:
        # Format first, so that a failed conversion does not leave a dangling
        # "PIT_out:" prefix in the legend.
        try:
            rate_text = f"{float(pit_out_rate):.4f}"
        except (TypeError, ValueError):
            print("Unsupported format for pit_out_rate.")
        else:
            label += "PIT$_{out}$: " + rate_text

    plt.figure(figsize=[4, 5])
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
    ax0 = plt.subplot(gs[0])
    sample = Sample(pdfs, zgrid, ztrue)

    if show_qq:
        # Read the quantiles once and reuse them for both QQ panels.
        q_theory, q_data = sample.qq
        ax0.plot(q_theory, q_data, c='r',
                 linestyle='-', linewidth=3, label=label)
        ax0.plot([0, 1], [0, 1], color='k', linestyle='--', linewidth=2)
        ax0.set_ylabel("Q$_{data}$", fontsize=18)
        plt.ylim(-0.001, 1.001)
    plt.xlim(-0.001, 1.001)
    plt.title(title)
    if show_pit:
        fzdata = Ensemble(interp, data=dict(xvals=zgrid, yvals=pdfs))
        pitobj = PIT(fzdata, ztrue)
        pit_vals = np.array(pitobj.pit_samps)

        # Expected count per bin for a uniform PIT distribution. A sequence
        # of N bin edges defines N - 1 bins.
        try:
            n_bins = float(bins)
        except (TypeError, ValueError):
            n_bins = float(len(bins) - 1)
        y_uni = float(len(pit_vals)) / n_bins

        if not show_qq:
            ax0.hist(pit_vals, bins=bins, alpha=0.7, label=label)
            ax0.set_ylabel('Number')
            ax0.hlines(y_uni, xmin=0, xmax=1, color='k')
            plt.ylim(0, )
        else:
            ax1 = ax0.twinx()
            ax1.hist(pit_vals, bins=bins, alpha=0.7)
            ax1.set_ylabel('Number')
            ax1.hlines(y_uni, xmin=0, xmax=1, color='k')
    leg = ax0.legend(handlelength=0, handletextpad=0, fancybox=True)
    # Hide the legend markers so that only the text is shown. Newer
    # matplotlib exposes them as `legend_handles`; older releases only have
    # `legendHandles`.
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = leg.legendHandles
    for item in handles:
        item.set_visible(False)
    if show_qq:
        delta_q = q_data - q_theory
        ax2 = plt.subplot(gs[1])
        ax2.plot(q_theory, delta_q, c='r', linestyle='-', linewidth=3)
        plt.ylabel(r"$\Delta$Q", fontsize=18)
        ax2.plot([0, 1], [0, 0], color='k', linestyle='--', linewidth=2)
        plt.xlim(-0.001, 1.001)
        plt.ylim(np.min([-0.12, np.min(delta_q) * 1.05]),
                 np.max([0.12, np.max(delta_q) * 1.05]))
    if show_pit:
        if show_qq:
            plt.xlabel("Q$_{theory}$ / PIT Value", fontsize=18)
        else:
            plt.xlabel("PIT Value", fontsize=18)
    else:
        if show_qq:
            plt.xlabel("Q$_{theory}$", fontsize=18)
    if savefig:
        fig_filename = str("plot_pit_qq_" +
                           f"{(code).replace(' ', '_')}.png")
        plt.savefig(fig_filename)
    else:
        fig_filename = None

    return fig_filename


class Sample(Ensemble):
    """Expand qp.Ensemble to carry true redshifts and metadata.

    The ensemble is built with the ``interp`` parametrization from the PDF
    values tabulated on ``zgrid``. PIT values and QQ quantiles are computed
    on first access and cached on the instance.
    """

    def __init__(self, pdfs, zgrid, ztrue, photoz_mode=None, code="", name="", n_quant=100):
        """Class constructor

        Parameters
        ----------
        pdfs: `ndarray`
            photo-z PDFs array, shape=(Ngals, Nbins)
        zgrid: `ndarray`
            PDF bins centers, shape=(Nbins,)
        ztrue: `ndarray`
            true redshifts, shape=(Ngals,)
        photoz_mode: `ndarray`
            photo-z (PDF mode), shape=(Ngals,)
        code: `str`, (optional)
            algorithm name (for plot legends)
        name: `str`, (optional)
            sample name (for plot legends)
        n_quant: `int`, (optional)
            number of quantiles used by the `qq` property
        """

        super().__init__(interp, data=dict(xvals=zgrid, yvals=pdfs))
        self._pdfs = pdfs
        self._zgrid = zgrid
        self._ztrue = ztrue
        self._photoz_mode = photoz_mode
        self._code = code
        self._name = name
        self._n_quant = n_quant
        # Lazily computed caches, filled by the `pit` and `qq` properties.
        self._pit = None
        self._qq = None

    @property
    def code(self):
        """Photo-z code/algorithm name"""
        return self._code

    @property
    def name(self):
        """Sample name"""
        return self._name

    @property
    def ztrue(self):
        """True redshifts array"""
        return self._ztrue

    @property
    def zgrid(self):
        """Redshift grid (binning)"""
        return self._zgrid

    @property
    def photoz_mode(self):
        """Photo-z (mode) array"""
        return self._photoz_mode

    @property
    def n_quant(self):
        """Number of quantiles used for the QQ curve"""
        return self._n_quant

    @property
    def pit(self):
        """CDF of each PDF evaluated at its true redshift (cached)"""
        if self._pit is None:
            pit_array = np.array([self[i].cdf(self.ztrue[i])[0][0] for i in range(len(self))])
            self._pit = pit_array
        return self._pit

    @property
    def qq(self):
        """Tuple (q_theory, q_data) of theoretical and PIT quantiles (cached)

        `q_theory` holds `n_quant` evenly spaced probabilities in [0, 1] and
        `q_data` holds the corresponding quantiles of the PIT values.
        """
        if self._qq is None:
            q_theory = np.linspace(0., 1., self._n_quant)
            q_data = np.quantile(self.pit, q_theory)
            self._qq = (q_theory, q_data)
        return self._qq

    def __len__(self):
        """Number of objects; raises ValueError if PDFs and redshifts differ in length"""
        if len(self._ztrue) != len(self._pdfs):
            raise ValueError("Number of pdfs and true redshifts do not match!!!")
        return len(self._ztrue)

    def __str__(self):
        """Multi-line summary of the sample (name, code, sizes, z grid)"""
        code_str = f'Algorithm: {self._code}'
        name_str = f'Sample: {self._name}'
        line_str = '-' * (max(len(code_str), len(name_str)))
        text = str(line_str + '\n' +
                   name_str + '\n' +
                   code_str + '\n' +
                   line_str + '\n' +
                   f'{len(self)} PDFs with {len(self.zgrid)} probabilities each \n' +
                   f'qp representation: {self.gen_class.name} \n' +
                   f'z grid: {len(self.zgrid)} z values from {np.min(self.zgrid)} to {np.max(self.zgrid)} inclusive')
        return text

In [None]:
results = DS.read_file('pdfs_data', QPHandle, f'TPZ_demo_output.hdf5')

In [None]:
catalog= DS.read_file('test', Hdf5Handle,'/home/andreia.dourado/.conda/envs/tpz_lite/lib/python3.11/site-packages/test_data.hdf5')

In [None]:
ztrue = catalog.data['photometry']['redshift']
# Build the tables once and take the redshift grid directly from the
# ensemble metadata, so that it matches the grid the PDFs were evaluated on
# (which starts at the configured zmin, not necessarily at 0).
pdf_tables = results().build_tables()
pdfs = pdf_tables['data']['yvals']
zgrid = pdf_tables['meta']['xvals'][0]
z_max = zgrid[-1]

In [None]:
plot_pit_qq(pdfs,zgrid,ztrue,title='TPz_DP02')