In [1]:
import os
import numpy as np 
from scipy import stats
from sedflow import obs as Obs
from sedflow import train as Train

In [2]:
from IPython.display import IFrame
# --- plotting --- 
import corner as DFM
import matplotlib as mpl
import matplotlib.pyplot as plt
# Global matplotlib styling for every figure in this notebook.
# The PDF backend and LaTeX text rendering are intentionally left off here.
mpl.rcParams.update({
    # fonts and axes frame
    'font.family': 'serif',
    'axes.linewidth': 1.5,
    'axes.xmargin': 1,
    # x ticks
    'xtick.labelsize': 'x-large',
    'xtick.major.size': 5,
    'xtick.major.width': 1.5,
    # y ticks
    'ytick.labelsize': 'x-large',
    'ytick.major.size': 5,
    'ytick.major.width': 1.5,
    # legends are drawn without a surrounding box
    'legend.frameon': False,
})

In [3]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
from sbi import utils as Ut
from sbi import inference as Inference

# 1. Validate ANPE using p-p plot and Simulation-Based Calibration

## Load test data

In [4]:
# Directory holding the test-set arrays loaded below.
dat_dir = '/scratch/network/chhahn/sedflow/'
test_theta      = np.load(os.path.join(dat_dir, 'sedflow_p.test_inobs.v0.1.theta_unt.npy'))
test_enc_spec   = np.load(os.path.join(dat_dir, 'sedflow_p.test_inobs.v0.1.encoded.npy'))
test_enc_ivar   = np.load(os.path.join(dat_dir, 'sedflow_p.test_inobs.v0.1.ivar.encoded.npy'))
test_zred       = np.load(os.path.join(dat_dir, 'sedflow_p.test_inobs.v0.1.zred.npy'))

# Work on a copy so the raw `test_theta` array is not modified by the
# log-transform below.
x_test = test_theta.copy()
# convert gamma1, gamma2 (ZH NMF coefficients; columns 6 and 7) to log space
x_test[:, 6:8] = np.log10(x_test[:, 6:8])

# Conditioning vector: encoded spectrum, encoded inverse variance and redshift
# stacked column-wise (redshift is promoted to a column vector first).
y_test = np.concatenate([test_enc_spec, test_enc_ivar, test_zred[:,None]], axis=1)

## load samples from `SEDflow` ANPE

In [5]:
# Tag identifying which trained ANPE the samples file belongs to.
arch = '500x5.2'
# Reuse `dat_dir` from the data-loading cell instead of repeating the
# absolute path; the resulting file path is unchanged.
anpe_samples = np.load(os.path.join(dat_dir, 'anpe.sedflow_p.%s.samples.npy' % arch))

Calculate the percentile score and rank of each true parameter value among its posterior samples.

In [6]:
# For each of the first 1000 test galaxies and each parameter, compare the
# true value against the ANPE posterior samples:
#   * percentile score (scaled to [0, 1]) -> used for the p-p plot
#   * rank = number of samples strictly below the truth -> used for SBC
n_galaxies = 1000

pp_rows, rank_rows = [], []
for gal in range(n_galaxies):
    posterior = anpe_samples[gal, :, :]   # indexed as [sample, parameter]
    truths = x_test[gal]
    n_param = posterior.shape[1]

    pp_rows.append([
        stats.percentileofscore(posterior[:, k], truths[k]) / 100.
        for k in range(n_param)
    ])
    rank_rows.append([
        np.sum(posterior[:, k] < truths[k])
        for k in range(n_param)
    ])

# Both arrays are indexed as [galaxy, parameter].
pp_thetas = np.array(pp_rows)
rank_thetas = np.array(rank_rows)

### p-p plot

In [7]:
# Labels for the 11 parameters, indexed in the same order as the columns of
# `pp_thetas`.
theta_lbls = [r'$\log M_*$', r"$\beta'_1$", r"$\beta'_2$", r"$\beta'_3$", r'$f_{\rm burst}$', r'$t_{\rm burst}$', r'$\log \gamma_1$', r'$\log \gamma_2$', r'$\tau_1$', r'$\tau_2$', r'$n_{\rm dust}$']
# Figures from here on are written to PDF with LaTeX-rendered text and shown
# through an IFrame rather than inline.
mpl.use('PDF')
mpl.rcParams['text.usetex'] = True

fig = plt.figure(figsize=(8,8))
sub = fig.add_subplot(111)
for itheta in range(pp_thetas.shape[1]): 
    # Histogram the percentile scores over the full [0, 1] interval so every
    # parameter shares the same bin edges.
    values, base = np.histogram(pp_thetas[:,itheta], bins=40, range=(0., 1.))
    # Empirical CDF evaluated at the bin edges: the cumulative fraction after
    # bin i belongs to that bin's *right* edge, so prepend 0 for the left-most
    # edge and plot against all edges. (Plotting against the left edges would
    # shift the whole curve one bin to the left.)
    cumulative = np.concatenate([[0.], np.cumsum(values) / np.sum(values)])
    sub.plot(base, cumulative, label=theta_lbls[itheta])
    
# Empty line used only to add an 'MCMC' entry to the legend.
sub.plot([], [], c='gray', ls=':', label='MCMC')
# One-to-one reference line for comparison with the CDFs.
sub.plot([0., 1.], [0., 1.], c='k', ls='--')
sub.set_xlim(0., 1.)
sub.set_ylim(0., 1.)
sub.legend(loc='upper left', fontsize=15)
fig.savefig('../../paper2/figs/ppplot.pdf', bbox_inches='tight')

IFrame("../../paper2/figs/ppplot.pdf", width=600, height=600)

### simulation-based calibration
This metric is from Talts+(2020); it uses the rank statistic rather than the percentile score.

In [9]:
# Number of posterior samples per galaxy. The ranks in `rank_thetas` count
# samples below the truth, so dividing by this maps them onto [0, 1].
n_samples = float(anpe_samples.shape[1])

fig = plt.figure(figsize=(15,9))
for i in range(pp_thetas.shape[1]): 
    sub = fig.add_subplot(3,4,i+1)
    # Normalised rank histogram for parameter i.
    sub.hist(rank_thetas[:,i]/n_samples, density=True, histtype='step', linewidth=2)
    # Reference line at density 1 for comparison with the histogram.
    sub.plot([0., 1.], [1., 1.], c='k', ls='--')
    
    sub.text(0.05, 0.95, theta_lbls[i], ha='left', va='top', fontsize=20, transform=sub.transAxes)
    sub.set_xlim(0, 1.)
    sub.set_ylim(0., 2.)
    sub.set_yticklabels([])
    sub.set_xticklabels([])
# `sub` is now the last panel; add empty lines there so it carries the legend.
sub.plot([], [], c='C0', label='ANPE')
sub.plot([], [], c='gray', ls=':', label='MCMC')
sub.legend(loc='lower right', fontsize=15, handletextpad=0)
fig.savefig('../../paper2/figs/sbc.pdf', bbox_inches='tight')

IFrame("../../paper2/figs/sbc.pdf", width=600, height=600)    