# Análise dos resultados - Métricas

_Autores: Andreia Dourado, Bruno Moraes_

_Adaptado dos notebooks do rail_tpz (https://github.com/LSSTDESC/rail_tpz) e do Demo: RAIL Evaluation notebook (https://rail-hub.readthedocs.io/projects/rail-notebooks/en/latest/rendered/evaluation_examples/Evaluation_Demo.html)_

__Descrição: Análise das métricas para os resultados gerados na etapa Estimate para o TPZ.__

### 1. Importando as bibliotecas:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rail
import qp
from rail.core.data import TableHandle, PqHandle, ModelHandle, QPHandle, DataHandle, Hdf5Handle
from rail.core.stage import RailStage

In [None]:
# Handle to the data store shared by all RAIL stages.
DS = RailStage.data_store
# Set the class-level overwrite flag — presumably so that re-running a cell
# that registers an already-used key does not fail (confirm against RAIL docs).
DS.__class__.allow_overwrite = True

In [None]:
from qp.ensemble import Ensemble
from matplotlib import gridspec
from qp import interp
from qp.metrics.pit import PIT
from rail.evaluation.metrics.cdeloss import *
from rail.evaluation.evaluator import OldEvaluator
from utils import plot_pit_qq, ks_plot
import os
from rail.estimation.algos.naive_stack import NaiveStackSummarizer
from scipy.interpolate import UnivariateSpline

%matplotlib inline
%reload_ext autoreload
%autoreload 

### 2. Leitura dos arquivos

##### A célula abaixo é apenas para facilitar a transição entre os meus diretórios

In [None]:
# Run configuration used to build the directory and file names below.
sigma = 5  # appears in the paths as "{sigma}sigma"
random = 12  # NOTE(review): not referenced in the visible cells — confirm it is still needed
leaf = 2  # appears in the output/metrics names as "{leaf}leaf"
trees = 9  # NOTE(review): not referenced in the visible cells — confirm it is still needed

#### 2.1 Arquivo de teste utilizado no estimate:

In [None]:
# Path of the test file used in the estimate stage; the directory and file
# name depend on `sigma`.
ztrue_file=f'/lustre/t0/scratch/users/andreia.dourado/TCC/dp02/truth/{sigma}sigma/runs/test_file_dp02_truth_lsst_error_model_{sigma}sigma.hdf5'

In [None]:
# Read the test file through the data store as a TableHandle, under the key
# 'ztrue_data'.
ztrue_data = DS.read_file('ztrue_data', TableHandle, ztrue_file)

In [None]:
# Number of entries in the 'mag_g' column of the 'photometry' group.
len(ztrue_data.data['photometry']['mag_g'])

#### 2.2 Arquivo de output gerado no estimate:

In [None]:
# Path of the estimate-stage output; the name depends on `sigma` and `leaf`.
pdfs_file=f'/lustre/t0/scratch/users/andreia.dourado/TCC/dp02/truth/{sigma}sigma/runs/output_tpz_dp02_truth_lsst_error_model_{sigma}sigma_{leaf}leaf.hdf5'
# Display the resolved path as the cell output.
pdfs_file

In [None]:
# Read the estimate output through the data store as a QPHandle, under the
# key 'pdfs_data'.
tpzdata = DS.read_file('pdfs_data', QPHandle, pdfs_file)

#### 2.3 Lendo os valores de redshift true e gerados no estimate:

In [None]:
# 'redshift' column of the 'photometry' group of the test table.
ztrue = ztrue_data()['photometry']['redshift']
# Grid stored as 'xvals' in the ensemble metadata, flattened to 1-D.
zgrid = tpzdata().metadata()['xvals'].ravel()
# Mode of every PDF evaluated on that grid; used below as the point estimate.
photoz_mode = tpzdata().mode(grid=zgrid)
# Remove length-1 axes — presumably so the result lines up element-wise with
# `ztrue` (TODO confirm the shape returned by mode()).
z_mode= np.squeeze(photoz_mode)

In [None]:
# Register the loaded table and ensemble in the data store under the keys
# 'truth' and 'ensemble'.
# NOTE(review): `truth` is not referenced again in the visible cells.
truth = DS.add_data('truth', ztrue_data(), TableHandle)
ensemble = DS.add_data('ensemble', tpzdata(), QPHandle)

In [None]:
# Display the data store as the cell output.
DS 

In [None]:
# Number of point estimates; compare with len(ztrue) in the next cell.
len(z_mode)

In [None]:
# Number of true redshifts.
len(ztrue)

### 3. Métricas

__Caminho para salvar as imagens:__

In [None]:
# Prefix (directory plus "{leaf}leaf_") prepended to the name of every figure
# and file saved by the cells below.
path = f'/lustre/t0/scratch/users/andreia.dourado/TCC/dp02/truth/{sigma}sigma/metrics/{leaf}leaf_'

#### 3.1 Estimativa pontual

In [None]:
def plot_scatter(zphot,
                 ztrue,
                 zmin=0,
                 zmax=3,
                 bins=150,
                 cmap='viridis',
                 line_color='red',
                 line_width=0.2,
                 title='$z_{true}$ vs $z_{phot}$',
                 xlabel='z$_{true}$',
                 ylabel='z$_{phot}$',
                 fontsize_title=18,
                 fontsize_labels=15,
                 path_to_save=''):
    """
    Plot a 2-D histogram of photometric vs true redshift with a 1:1 line.

    The figure is saved as ``<prefix>point_estimate.png`` and then shown.

    Parameters
    ----------
    zphot : array-like
        Array of photometric redshifts (y-axis).
    ztrue : array-like
        Array of true redshifts (x-axis).
    zmin : float, optional
        Minimum redshift of both axes and of the diagonal line.
    zmax : float, optional
        Maximum redshift of both axes and of the diagonal line.
    bins : int, optional
        Number of bins for the histogram.
    cmap : str, optional
        Colormap to be used for the histogram.
    line_color : str, optional
        Color of the diagonal line.
    line_width : float, optional
        Width of the diagonal line.
    title : str, optional
        Title of the plot.
    xlabel : str, optional
        Label for the x-axis.
    ylabel : str, optional
        Label for the y-axis.
    fontsize_title : int, optional
        Font size for the title.
    fontsize_labels : int, optional
        Font size for the x and y labels.
    path_to_save : str, optional
        Prefix prepended to ``point_estimate.png``. When left empty, the
        notebook-level ``path`` prefix is used, as in the original version.

    Returns
    -------
    None
    """
    sns.histplot(x=ztrue, y=zphot, bins=bins, cmap=cmap)
    # Draw the 1:1 line over the plotted range instead of a hard-coded [0, 3].
    plt.plot([zmin, zmax], [zmin, zmax],
             color=line_color, linewidth=line_width)
    plt.xlim(zmin, zmax)
    plt.ylim(zmin, zmax)
    plt.xlabel(xlabel, fontsize=fontsize_labels)
    plt.ylabel(ylabel, fontsize=fontsize_labels)
    plt.title(title, fontsize=fontsize_title)

    # Honour the caller's prefix; fall back to the notebook-level `path` so
    # existing calls keep writing to the same location.
    prefix = path_to_save if path_to_save else path
    plt.savefig(f'{prefix}point_estimate.png')

    plt.show()

In [None]:
# Point estimate (mode) against the true redshift, with default plot settings.
plot_scatter(z_mode,ztrue)

#### 3.2. PDF individual

In [None]:
# Index of the object whose PDF is displayed.
which=1355
fig, axs = plt.subplots()
# Draw the PDF stored for that object in the ensemble.
tpzdata().plot_native(key=which,axes=axs, label=f"PDF for galaxy {which}")
# Vertical dashed line at the true redshift of the same object.
axs.axvline(ztrue[which],c='r',ls='--', label="true redshift")
plt.legend(loc='upper right', fontsize=12)
axs.set_xlabel("redshift")
#plt.savefig('example_pdf_NaN.png')

#### 3.3 Métricas básicas

In [None]:
def _binned_point_metrics(zspec, zphot, bin_edges):
    """Return per-bin bias, sigma_68 and 2-sigma / 3-sigma outlier fractions.

    Objects are assigned to bins of ``zphot``. Bins are half-open,
    ``[lower, upper)``, so an object lying exactly on an edge is counted once.
    Empty bins yield ``nan`` for every metric instead of raising.
    """
    bias, sigma68, out_2sigma, out_3sigma = [], [], [], []

    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (zphot >= lower) & (zphot < upper)

        if not np.any(in_bin):
            # Nothing to measure here; keep the lists aligned with the bins.
            for series in (bias, sigma68, out_2sigma, out_3sigma):
                series.append(np.nan)
            continue

        true_z = zspec[in_bin]
        delta = zphot[in_bin] - true_z
        mean_delta = np.mean(delta)
        bias.append(mean_delta)

        # sigma_68: the value below which ~68% of |dz| / (1 + z_spec) falls.
        scaled = np.sort(np.abs(delta / (1 + true_z)))
        sigma68.append(scaled[int(len(scaled) * 0.68)])

        # Standard deviation of dz about its mean, used for the outlier cuts.
        scatter = np.sqrt(np.mean((delta - mean_delta) ** 2))
        deviation = np.abs(delta - mean_delta)
        out_2sigma.append(np.count_nonzero(deviation > 2 * scatter) / len(delta))
        out_3sigma.append(np.count_nonzero(deviation > 3 * scatter) / len(delta))

    return bias, sigma68, out_2sigma, out_3sigma


def plot_metrics(zspec,
                 zphot,
                 maximum,
                 path_to_save='',
                 title=None,
                 initial=0):
    """
    Plot bias, sigma_68, out_2sigma and out_3sigma in bins of photometric redshift.

    The objects are grouped in bins of width 0.1 in ``zphot`` between
    ``initial`` and ``maximum``; each metric is drawn at the bin centres in its
    own panel. The figure is saved as ``<prefix>metrics.png``.

    Parameters
    ----------
    zspec : array-like
        Spectroscopic (true) redshifts.
    zphot : array-like
        Photometric redshifts, same length as ``zspec``.
    maximum : float
        Upper limit of the redshift binning.
    path_to_save : str, optional
        Prefix prepended to ``metrics.png``. When left empty, the
        notebook-level ``path`` prefix is used, as in the original version.
    title : str, optional
        Figure title.
    initial : float, optional
        Lower limit of the redshift binning.

    Returns
    -------
    None

    Notes
    -----
    The shared x-axis is fixed to the range 0-3, as in the original
    implementation, independently of ``initial`` and ``maximum``.
    """
    zspec = np.asarray(zspec)
    zphot = np.asarray(zphot)

    bins = np.arange(initial, maximum, 0.1)
    centres = (bins + 0.05)[:-1]
    meanz, sigma68z, outliers_2, fraction_outliers = _binned_point_metrics(
        zspec, zphot, bins)

    fig, axes = plt.subplots(4, 1, figsize=(8, 14), sharex=True)
    plt.subplots_adjust(hspace=0.001)

    # Panel 1: mean bias.
    axes[0].plot(centres, meanz, 'bo-')
    axes[0].axhline(0, color='black', linestyle='--')
    axes[0].set_ylabel(r'$\Delta z$', fontsize=20)

    # Panel 2: sigma_68, with a reference line at 0.12.
    axes[1].plot(centres, sigma68z, 'go-')
    axes[1].axhline(0.12, color='black', linestyle='--')
    axes[1].set_ylabel(r'$\sigma_{68}$', fontsize=20)
    axes[1].set_yticks(np.arange(0, 0.5, 0.05))

    # Panel 3: fraction of objects beyond 2 sigma, reference line at 0.1.
    axes[2].plot(centres, outliers_2, 'o-', color='darkorange')
    axes[2].axhline(0.1, color='black', linestyle='--')
    axes[2].set_ylabel('out$_{2σ}$', fontsize=20)

    # Panel 4: fraction of objects beyond 3 sigma, reference line at 0.1.
    axes[3].plot(centres, fraction_outliers, 'ro-')
    axes[3].axhline(0.1, color='black', linestyle='--')
    axes[3].set_ylabel('out$_{3σ}$', fontsize=20)
    # The bins are defined on zphot, so that is what the x-axis shows.
    axes[3].set_xlabel(r'$Z_{phot}$', fontsize=20)

    for ax in axes:
        ax.tick_params(axis='both', labelsize=14)
        ax.grid(True)

    # The axes share x, so one call sets the range for every panel.
    axes[-1].set_xlim(0, 3)

    if title is not None:
        plt.suptitle(title)
    plt.tight_layout()

    # Honour the caller's prefix; fall back to the notebook-level `path` so
    # existing calls keep writing to the same location.
    prefix = path_to_save if path_to_save else path
    plt.savefig(f'{prefix}metrics.png')

In [None]:
# Binned metrics from 0 up to 0.2 below the largest point estimate.
plot_metrics(ztrue,z_mode,max(z_mode)-0.2,initial=0)

#### 3.4 PIT QQ

In [None]:
# Build the PIT object from the ensemble of PDFs and the true redshifts.
pitobj = PIT(tpzdata(), ztrue)
# NOTE(review): `quant_ens` is not referenced again in the visible cells.
quant_ens = pitobj.pit
# Meta-metrics of the PIT distribution; read below via the 'outlier_rate' key.
metamets = pitobj.calculate_pit_meta_metrics()

In [None]:
# Display the PIT meta-metrics as the cell output.
metamets

In [None]:
# PIT samples of the PIT object as a NumPy array, displayed as the cell output.
pit_vals = np.array(pitobj.pit_samps)
pit_vals

In [None]:
# Outlier rate taken from the meta-metrics dictionary computed above.
pit_out_rate = metamets['outlier_rate']
print(f"PIT outlier rate of this sample: {pit_out_rate:.6f}")
# Outlier rate requested directly from the PIT object — presumably the same
# quantity (confirm). This assignment overwrites the previous value and is the
# one passed to plot_pit_qq below; both prints use the same label.
pit_out_rate = pitobj.evaluate_PIT_outlier_rate()
print(f"PIT outlier rate of this sample: {pit_out_rate:.6f}")

In [None]:
# 'yvals' entry of the ensemble's object data; passed to plot_pit_qq together
# with `zgrid` in the next cell.
pdfs = tpzdata.data.objdata()['yvals']

In [None]:
plt.figure(dpi=300) # print-quality resolution
# PIT-QQ plot produced by the helper imported from utils.
# NOTE(review): the helper is called with savefig=True and the figure is also
# saved explicitly below — confirm whether both outputs are wanted.
plot_pit_qq(pdfs, zgrid, ztrue, title="PIT-QQ - toy data", code="TPZ",
                pit_out_rate=pit_out_rate, savefig=True)
plt.savefig(f'{path}PITQQ.png')

#### 3.5 N(z)

In [None]:
# Configure the naive-stacking summarizer: 301 bins from 0 to the largest
# mode, 20 samples, with both outputs written under the `path` prefix.
# NOTE(review): max() is applied to the un-squeezed `photoz_mode` — confirm
# that the resulting value is accepted as `zmax`.
stacker = NaiveStackSummarizer.make_stage(zmin=0.0, zmax=max(photoz_mode), nzbins=301, nsamples=20, hdf5_groupname=None, output=f"{path}Naive_samples.hdf5", single_NZ=f"{path}NaiveStack_NZ.hdf5")

In [None]:
# Run the summarizer on the ensemble handle; the single N(z) file it was
# configured to write is read back in a later cell.
naive_results = stacker.summarize(tpzdata)

In [None]:
#newens = naive_results.data

In [None]:
fig = plt.figure(figsize=(8, 6))
plt.title('Histogram of the True redshift', fontsize = 20)
plt.xlabel('z', fontsize = 17)
plt.grid(color = 'gray', linestyle = '--', linewidth = 0.5)
# `z` holds (densities, bin edges, patches); the densities are reused by the
# spline cell below.
z = plt.hist(ztrue, bins=zgrid, density=True, color = 'dodgerblue',
             label='z$_{true}$')
# Kept under its own name: the original assigned this tuple to `zmode`, the
# name later used for the point-estimate array, so re-running this cell out of
# order silently broke the requirement cells. Drawn semi-transparent so the
# first histogram stays visible underneath.
zmode_hist = plt.hist(photoz_mode, bins=zgrid, density=True, color = 'orange',
                      alpha=0.6, label='z$_{phot}$ (mode)')
plt.legend(fontsize=15)

In [None]:
# Each histogram density describes a whole bin, so anchor it at the bin centre;
# pairing it with the left edge (zgrid[:-1]) shifts the curve by half a bin.
bin_centers = 0.5 * (zgrid[:-1] + zgrid[1:])
cs = UnivariateSpline(bin_centers, z[0])
# Smoothing factor of the spline fit.
cs.set_smoothing_factor(0.01)

In [None]:
# Read back the single N(z) file written by the summarizer stage.
varinf_nz = qp.read(f"{path}NaiveStack_NZ.hdf5")
varinf_nz.plot_native(xlim=(0,3), color = 'orange', label = 'z$_{phot TPZ}$ PDF')
# Overlay the smoothed spline of the true-redshift histogram on the same grid.
plt.plot(zgrid,cs(zgrid), color = 'dodgerblue', label = 'z$_{spec}$ PDF')
#plt.plot(zgrid,tpz['data']['yvals'][0], color = 'red', label = 'z$_{photo TPZ}$ PDF')
plt.legend(fontsize = 15)
plt.savefig(f'{path}n(z).png')

#### 3.6 Point Evaluation - necessário apenas se for feita a comparação com os requirements

In [None]:
from rail.evaluation.point_to_point_evaluator import PointToPointEvaluator

In [None]:
#ztrue_file_new='/home/andreia.dourado/ic-photoz/andreia_dourado/test_random_photometry.hdf5'
#ztrue_file_new

In [None]:
#ztrue_photometry = DS.read_file('ztrue_photometry', TableHandle, ztrue_file_new)
#len(ztrue_photometry()['photometry']['redshift'])

In [None]:
# Point estimate with length-1 axes removed; used by the requirement cells
# below. The length is displayed as the cell output.
zmode = photoz_mode.squeeze()
len(zmode)

In [None]:
# Configuration shared by the point-to-point evaluator stages created below:
# the metrics to compute, the table group and point-estimate key to read, the
# chunk size, and a per-metric option for 'point_stats_iqr'.
stage_dict = dict(
    metrics=['point_stats_ez', 'point_stats_iqr', 'point_bias', 'point_outlier_rate', 'point_stats_sigma_mad'],
    _random_state=None,
    hdf5_groupname= 'photometry',
    point_estimate_key='zmode',
    chunk_size=10000,
    metric_config={
        'point_stats_iqr':{'tdigest_compression': 100},
    }
)
# NOTE(review): both stages are created with the same name and only
# `ptp_stage_single` is used in the visible cells — confirm that `ptp_stage`
# is still needed.
ptp_stage = PointToPointEvaluator.make_stage(name='point_to_point', **stage_dict)
# Same configuration with force_exact=True.
ptp_stage_single = PointToPointEvaluator.make_stage(name='point_to_point', force_exact=True, **stage_dict)

In [None]:
import tables_io

In [None]:
# Evaluate the configured metrics on the ensemble.
# NOTE(review): the raw `ztrue` array is passed as truth although the stage is
# configured with hdf5_groupname='photometry' and a `truth` table handle was
# registered earlier — confirm this is the intended truth input.
ptp_results_single = ptp_stage_single.evaluate(ensemble, ztrue)
# Convert the stage's 'summary' output to a pandas DataFrame for display.
results_summary_single = tables_io.convertObj(ptp_stage_single.get_handle('summary')(), tables_io.types.PD_DATAFRAME)
results_summary_single

In [None]:
# Scaled residual of the point estimate: (z_phot - z_true) / (1 + z_true).
ez = (zmode-ztrue)/(1+ztrue)
# Root-mean-square scatter of the residuals about their mean.
rms = np.sqrt(np.mean((ez-np.mean(ez))**2))

In [None]:
# Normalized median absolute deviation: 1.4826 * median(|ez - median(ez)|).
# The 1.4826 factor scales a *median* absolute deviation to the standard
# deviation of a Gaussian; the original used means, which does not give the
# NMAD and is sensitive to the outliers the statistic is meant to resist.
nmad = 1.4826 * np.median(np.abs(ez - np.median(ez)))
nmad

In [None]:
# Look the summary table up once instead of once per reported value.
ptp_summary = ptp_stage_single.get_handle("summary")()

# One entry per line. The original used backslash continuations inside a
# single f-string, which produced one run-on line with the source indentation
# embedded in it.
text = "\n".join([
    'RMS LSST requirement: 0.05 (goal: 0.02)',
    f'RMS: {np.round(rms,4)}',
    'Bias LSST requirement: 0.003',
    f'Bias: {np.round(ptp_summary["point_bias"][0],4)}',
    'Fraction Outliers LSST requirement: 0.1',
    f'Fraction Outliers: {np.round(ptp_summary["point_outlier_rate"][0],4)}',
    f'Sigma MAD: {np.round(ptp_summary["point_stats_sigma_mad"][0],4)}',
])
print(text)

##### Salvando as comparações com os requirements em um .txt:

In [None]:
# Write the requirement comparison text to "<path>requirements_<sigma>sigma.txt",
# replacing any existing file of that name.
with open(f'{path}requirements_{sigma}sigma.txt', 'w') as file:
    file.write(text)

##### Plots das métricas dos requirements:

In [None]:
# Scaled residuals against the point estimate, with the overall +/- RMS band.
plt.scatter(zmode, ez, s=0.01)
for rms_level in (rms, -rms):
    plt.axhline(rms_level, color='red')
plt.ylim(-0.5, 0.5)
plt.ylabel('ez')
plt.xlabel('zphot')
plt.savefig(f'{path}ez_cuts_.png')

In [None]:
# RMS of the scaled residuals in 49 bins of the point estimate over 0.1-3.
z_bins = np.linspace(0.1, 3, 50)
rms_all = []

# Walk over consecutive (lower, upper) edge pairs; objects lying exactly on an
# edge are excluded by the strict inequalities.
for z_low, z_high in zip(z_bins[:-1], z_bins[1:]):
    in_bin = (zmode > z_low) & (zmode < z_high)
    ez_bin = (zmode[in_bin] - ztrue[in_bin]) / (1 + ztrue[in_bin])
    rms_all.append(np.sqrt(np.mean((ez_bin - np.mean(ez_bin))**2)))

# Each value is drawn at the lower edge of its bin; the dashed lines sit at
# 0.02 (red) and 0.05 (orange).
plt.scatter(z_bins[:-1], rms_all, s=15, color='blue')
plt.ylim(0, 0.15)
plt.axhline(0.02, color='red', ls='--')
plt.axhline(0.05, color='orange', ls='--')
plt.ylabel('RMS', fontsize=16)
plt.xlabel('zphot', fontsize=16)
plt.savefig(f'{path}rms_cuts_.png')

In [None]:
# IQR-based sigma from the evaluator summary, looked up once.
sigma_iqr = ptp_stage_single.get_handle('summary')()['point_stats_iqr'][0]

# Outlier threshold on |ez|: the larger of 0.06 and 3 * sigma_IQR.
cutcriterion_all = np.maximum(0.06, 3*sigma_iqr)
mask = (np.fabs(ez) > cutcriterion_all)
points=np.linspace(0,3.3,1000)

plt.scatter(ztrue[mask],zmode[mask],s=0.1,color='blue')
# Empty series with a visible marker size, only to create the legend entry.
plt.scatter([],[],color='blue',s=13,label='outliers')
plt.scatter(ztrue[~mask],zmode[~mask],s=0.1,color='black')
# Boundary curves |zphot - ztrue| = cut * (1 + ztrue). They use the same
# threshold as the mask; the original drew them with 3 * sigma_IQR, which
# disagrees with the coloured points whenever 3 * sigma_IQR < 0.06.
plt.scatter(points,points+cutcriterion_all*(1+points),color='red',s=0.1)
plt.scatter(points,points-cutcriterion_all*(1+points),color='red',s=0.1)
plt.xlim(0,3.1)
plt.ylim(0,3.1)
plt.legend(fontsize=16,loc=2)
plt.xlabel('ztrue',fontsize=16)
plt.ylabel('zphot',fontsize=16)
plt.savefig(f'{path}outliers_cuts_.png')