# Results analysis

_Adapted from Sam Schmidt's example notebook (https://github.com/LSSTDESC/rail_tpz) and the "Demo: RAIL Evaluation" notebook (https://rail-hub.readthedocs.io/projects/rail-notebooks/en/latest/rendered/evaluation_examples/Evaluation_Demo.html)._

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rail
import qp
from rail.core.data import TableHandle, PqHandle, ModelHandle, QPHandle, DataHandle, Hdf5Handle
from rail.core.data import TableHandle
from rail.core.stage import RailStage

In [None]:
# Handle to RAIL's shared data store; the cells below use it to read files
# and to register in-memory data under string keys.
DS = RailStage.data_store
# Set allow_overwrite on the store's class — presumably so that re-running a
# cell can replace an existing key instead of raising; confirm in the RAIL docs.
DS.__class__.allow_overwrite = True

In [None]:
from qp.ensemble import Ensemble
from matplotlib import gridspec
from qp import interp
from qp.metrics.pit import PIT
from rail.evaluation.metrics.cdeloss import *
from rail.evaluation.evaluator import OldEvaluator
from utils import plot_pit_qq, ks_plot
import os


%matplotlib inline
%reload_ext autoreload
%autoreload 2

### 1. Reading the data

In [None]:
# HDF5 file with the test sample; its 'redshift' column is used as the truth.
# NOTE(review): absolute, user-specific path — adjust before running elsewhere.
ztrue_file='/home/andreia.dourado/data/test_data.hdf5'

In [None]:
# Read the test sample through the data store as a TableHandle, registered
# under the key 'ztrue_data'. Calling the handle (ztrue_data()) yields the table.
ztrue_data = DS.read_file('ztrue_data', TableHandle, ztrue_file)

In [None]:
# HDF5 file that is read below as a qp ensemble (QPHandle).
# NOTE(review): absolute, user-specific path — adjust before running elsewhere.
pdfs_file='/home/andreia.dourado/data/output.hdf5'

In [None]:
# Read the PDFs through the data store as a QPHandle, registered under the
# key 'pdfs_data'. Calling the handle (tpzdata()) yields the ensemble.
tpzdata = DS.read_file('pdfs_data', QPHandle, pdfs_file)

In [None]:
# True redshifts: the 'redshift' column of the test table.
ztrue = ztrue_data()['redshift']
# Redshift grid taken from the ensemble metadata entry 'xvals', flattened to 1-D.
zgrid = tpzdata().metadata()['xvals'].ravel()
# Mode of each PDF evaluated on that grid.
# NOTE(review): photoz_mode is not used by the later cells shown here; the
# point estimate plotted below is read from the ensemble's ancil['zmode'].
photoz_mode = tpzdata().mode(grid=zgrid)

In [None]:
# Register the already-loaded table and ensemble in the data store under the
# keys 'truth' and 'ensemble'; these handles are what the evaluator cells use.
truth = DS.add_data('truth', ztrue_data(), TableHandle)
ensemble = DS.add_data('ensemble', tpzdata(), QPHandle)

In [None]:
# Show the keys currently held by the data store.
DS.keys()

### 2. Point estimates

In [None]:
# True redshifts again (same column as ztrue), under the name used by the plots.
sz = ztrue_data()['redshift']
# Point estimate read from the ensemble's ancillary data under the key 'zmode'.
zmode = tpzdata().ancil['zmode']

In [None]:
# Scatter of the TPZ mode against the true redshift, one small black dot per object.
plt.figure(figsize=(8,8))
plt.scatter(sz,zmode, s=2,c='k')
# Dashed red diagonal from (0, 0) to (3, 3): points on it have mode == true redshift.
plt.plot([0,3],[0,3],'r--')
plt.xlabel("redshift", fontsize=15)
plt.ylabel("TPZ mode", fontsize=15)
# Relative path, so the PNG lands in the current working directory.
plt.savefig('point_estimate.png')

### 2. An individual redshift PDF

In [None]:
# Index of the object whose PDF is shown.
which=1355
fig, axs = plt.subplots()
# Draw that object's PDF on axs using the ensemble's plot_native.
tpzdata().plot_native(key=which,axes=axs, label=f"PDF for galaxy {which}")
# Dashed red vertical line at the object's true redshift.
axs.axvline(sz[which],c='r',ls='--', label="true redshift")
plt.legend(loc='upper right', fontsize=12)
axs.set_xlabel("redshift")
# Relative path, so the PNG lands in the current working directory.
plt.savefig('example_pdf.png')

### 3. Basic metrics

In [None]:
def _binned_metrics(zspec, zphot, bins):
    """Compute per-bin statistics of ``zphot - zspec`` in bins of ``zphot``.

    Objects are assigned to half-open bins ``[bins[i], bins[i + 1])`` so that
    a value lying exactly on an interior edge is counted once. A bin with no
    objects contributes NaN to every statistic instead of raising.

    Returns four lists with one entry per bin, in this order: mean bias,
    sigma_68, fraction of 2-sigma outliers, fraction of 3-sigma outliers.
    """
    zspec = np.asarray(zspec)
    zphot = np.asarray(zphot)

    mean_bias, sigma_68, out_2sigma, out_3sigma = [], [], [], []

    for bin_lower, bin_upper in zip(bins[:-1], bins[1:]):
        in_bin = (zphot >= bin_lower) & (zphot < bin_upper)

        if not in_bin.any():
            # Nothing to measure here; NaN shows up as a gap in the curves.
            for stat in (mean_bias, sigma_68, out_2sigma, out_3sigma):
                stat.append(np.nan)
            continue

        spec_in_bin = zspec[in_bin]
        delta = zphot[in_bin] - spec_in_bin

        # Mean bias of the bin.
        bias = np.mean(delta)
        mean_bias.append(bias)

        # sigma_68: element at index int(0.68 * n) of the sorted
        # |delta| / (1 + zspec) values.
        scaled = np.sort(np.abs(delta / (1 + spec_in_bin)))
        sigma_68.append(scaled[int(len(scaled) * 0.68)])

        # Population standard deviation of delta; used only as the yardstick
        # for the outlier fractions below.
        sigma = np.sqrt(np.mean((delta - bias) ** 2))
        deviation = np.abs(delta - bias)
        out_2sigma.append(np.mean(deviation > 2 * sigma))
        out_3sigma.append(np.mean(deviation > 3 * sigma))

    return mean_bias, sigma_68, out_2sigma, out_3sigma


def plot_metrics(zspec,
                 zphot,
                 maximum,
                 path_to_save='',
                 title=None,
                 initial=0,
                 xlim=(0, 3)):
    """
    Plot Bias, Sigma_68, Out2σ and Out3σ as a function of redshift, given
    spectroscopic and photometric redshifts. The figure is saved to disk and
    then shown.

    Objects are grouped in bins of width 0.1 in ``zphot`` between ``initial``
    and ``maximum``; each statistic is drawn at the bin centre.

    Args:

    zspec: Numpy array with the spectroscopic redshift.

    zphot: Numpy array with the photometric redshifts calculated. Same size as zspec.

    maximum: Float, upper end of the binning range.

    Kwargs:

    path_to_save: Where to save the figure. An empty string (default) saves
        'metrics.png' in the current directory; an existing directory saves
        'metrics.png' inside it; anything else is used as the output file name.

    title: Figure title passed to ``plt.suptitle``.

    initial: Float, lower end of the binning range.

    xlim: (min, max) of the shared x axis. Defaults to (0, 3).
    """
    import os

    bins = np.arange(initial, maximum, 0.1)
    centers = (bins + 0.05)[:-1]
    meanz, sigma68z, outliers_2, fraction_outliers = _binned_metrics(
        zspec, zphot, bins)

    fig, axes = plt.subplots(4, 1, figsize=(8, 14), sharex=True)
    plt.subplots_adjust(hspace=0.001)

    # Subplot 1: Mean Bias
    axes[0].plot(centers, meanz, 'bo-')
    axes[0].axhline(0, color='black', linestyle='--')
    axes[0].set_ylabel(r'$\Delta z$', fontsize=20)
    axes[0].set_ylim(-0.05, 0.05)

    # Subplot 2: Sigma 68
    axes[1].plot(centers, sigma68z, 'go-')
    axes[1].set_ylabel(r'$\sigma_{68}$', fontsize=20)
    axes[1].axhline(0.12, color='black', linestyle='--')
    axes[1].set_ylim(0, 0.03)
    # NOTE(review): these ticks run past the limit set just above, and
    # matplotlib may widen the view to include them — confirm the intended range.
    axes[1].set_yticks(np.arange(0, 0.5, 0.05))

    # Subplot 3: 2-sigma outliers
    axes[2].plot(centers, outliers_2, 'o-', color='darkorange')
    axes[2].set_ylabel('out$_{2σ}$', fontsize=20)
    axes[2].set_ylim(0, 0.12)
    axes[2].axhline(0.1, color='black', linestyle='--')

    # Subplot 4: 3-sigma outliers
    axes[3].plot(centers, fraction_outliers, 'ro-')
    # NOTE(review): the bins are defined on zphot, yet the axis is labelled
    # Z_spec — confirm which one is intended.
    axes[3].set_xlabel(r'$Z_{spec}$', fontsize=20)
    axes[3].set_ylabel('out$_{3σ}$', fontsize=20)
    axes[3].set_ylim(0, 0.02)
    axes[3].axhline(0.015, color='black', linestyle='--')

    for ax in axes:
        ax.tick_params(axis='both', labelsize=14)
        ax.grid(True)

    # The x axis is shared, so setting the range once applies to all panels.
    axes[-1].set_xlim(xlim)

    plt.suptitle(title)
    plt.tight_layout()

    if path_to_save and os.path.isdir(path_to_save):
        output_file = os.path.join(path_to_save, 'metrics.png')
    else:
        output_file = path_to_save or 'metrics.png'
    plt.savefig(output_file)

    plt.show()

In [None]:
# Remove length-1 axes from zmode before using it for masking and plotting —
# presumably ancil['zmode'] carries a singleton axis; confirm its stored shape.
zmode= np.squeeze(zmode)
# Bin from 0 up to 0.2 below the largest mode value.
plot_metrics(sz,zmode,max(zmode)-0.2,initial=0)

### 4. RAIL Evaluation 

In [None]:
# Build the evaluation stage named 'TPZ_eval', configured with the 'truth' handle.
TPZ_eval = OldEvaluator.make_stage(name='TPZ_eval', truth=truth)

In [None]:
# Run the evaluator on the ensemble against the truth handle; the result is a
# handle whose contents are read with TPZ_results().
TPZ_results = TPZ_eval.evaluate(ensemble(), truth)

In [None]:
import tables_io
# Convert the evaluator output to a pandas DataFrame for display.
results_df= tables_io.convertObj(TPZ_results(), tables_io.types.PD_DATAFRAME)
results_df

In [None]:
# PIT object built from the ensemble and the true redshifts.
pitobj = PIT(tpzdata(), ztrue)
# NOTE(review): quant_ens is not used by the later cells shown here.
quant_ens = pitobj.pit
# Meta-metrics of the PIT; the cells below read its 'outlier_rate' and 'ks' entries.
metamets = pitobj.calculate_pit_meta_metrics()

In [None]:
# Display the PIT meta-metrics.
metamets

In [None]:
# PIT samples of the sample as a numpy array, displayed for inspection.
pit_vals = np.array(pitobj.pit_samps)
pit_vals

In [None]:
# Outlier rate obtained two ways and printed with the same message each time:
# first from the meta-metrics dictionary...
pit_out_rate = metamets['outlier_rate']
print(f"PIT outlier rate of this sample: {pit_out_rate:.6f}")
# ...then directly from the PIT object. This second value is the one that
# remains in pit_out_rate and is passed to plot_pit_qq below.
pit_out_rate = pitobj.evaluate_PIT_outlier_rate()
print(f"PIT outlier rate of this sample: {pit_out_rate:.6f}")

In [None]:
# PDF values stored under 'yvals' in the ensemble's object data.
# NOTE(review): this reads tpzdata.data rather than calling tpzdata() as the
# other cells do, and calls objdata as a method — confirm both against the
# installed RAIL/qp versions.
pdfs = tpzdata.data.objdata()['yvals']
# PIT-QQ plot from the local utils module; savefig=False is passed to the helper.
plot_pit_qq(pdfs, zgrid, ztrue, title="PIT-QQ - toy data", code="TPZ",
                pit_out_rate=pit_out_rate, savefig=False)

In [None]:
# KS result obtained two ways and printed with the same message each time:
# first from the meta-metrics dictionary...
ks_stat_and_pval = metamets['ks']
print(f"PIT KS stat and pval: {ks_stat_and_pval}")
# ...then directly from the PIT object. This second value is the one whose
# .statistic attribute is printed in the last cell.
ks_stat_and_pval = pitobj.evaluate_PIT_KS()
print(f"PIT KS stat and pval: {ks_stat_and_pval}")

In [None]:
# KS plot for the PIT object, drawn by the helper from the local utils module.
ks_plot(pitobj)

In [None]:
# Print only the KS statistic, to four decimal places.
print(f"KS metric of this sample: {ks_stat_and_pval.statistic:.4f}")