In [1]:
import os
import numpy as np 
from scipy import stats
from sedflow import obs as Obs
from sedflow import train as Train

In [2]:
from IPython.display import IFrame
# --- plotting --- 
import corner as DFM
import matplotlib as mpl
import matplotlib.pyplot as plt
#mpl.use('PDF')
#mpl.rcParams['text.usetex'] = True
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['axes.linewidth'] = 1.5
mpl.rcParams['axes.xmargin'] = 1
mpl.rcParams['xtick.labelsize'] = 'x-large'
mpl.rcParams['xtick.major.size'] = 5
mpl.rcParams['xtick.major.width'] = 1.5
mpl.rcParams['ytick.labelsize'] = 'x-large'
mpl.rcParams['ytick.major.size'] = 5
mpl.rcParams['ytick.major.width'] = 1.5
mpl.rcParams['legend.frameon'] = False

In [3]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
from sbi import utils as Ut
from sbi import inference as Inference

In [4]:
def dirichlet_transform(tt): 
    ''' warped manifold transformation as specified in Betancourt (2013).
    This function transforms samples from a uniform distribution to a
    Dirichlet distribution .

    x_i = (\prod\limits_{k=1}^{i-1} z_k) * f 

    f = 1 - z_i         for i < m
    f = 1               for i = m 

    Parameters
    ----------
    tt : array_like[N,m-1]
        N samples drawn from a (m-1)-dimensional uniform distribution 

    Returns
    -------
    tt_d : array_like[N,m]
        N transformed samples drawn from a m-dimensional dirichlet
        distribution 

    Reference
    ---------
    * Betancourt(2013) - https://arxiv.org/pdf/1010.3436.pdf
    '''
    tt_d = np.empty(tt.shape[:-1]+(tt.shape[1]+1,)) 

    tt_d[...,0] = 1. - tt[...,0]
    for i in range(1,tt.shape[1]): 
        tt_d[...,i] = np.prod(tt[...,:i], axis=-1) * (1. - tt[...,i]) 
    tt_d[...,-1] = np.prod(tt, axis=-1) 
    return tt_d 

# 1. Validate ANPE using p-p plot and Simulation-Based Calibration

## Load test data

In [5]:
# x = theta_sps
# y = [u, g, r, i, z, sigma_u, sigma_g, sigma_r, sigma_i, sigma_z, z]
_x_test, y_test = Train.load_data('test', version=1, sample='toy', params='thetas_unt')

In [6]:
x_test = np.zeros((_x_test.shape[0], _x_test.shape[1]+1))
x_test[:,0] = _x_test[:,0]

# transform back to dirichlet space
x_test[:,1:5] = dirichlet_transform(_x_test[:,1:4])
x_test[:,5:] = _x_test[:,4:]

# log gamma1, gamma2
x_test[:,7] = np.log10(x_test[:,7])
x_test[:,8] = np.log10(x_test[:,8])

## load samples from `SEDflow` ANPE

In [7]:
arch = '500x10.4'
anpe_samples = np.load('/scratch/network/chhahn/sedflow/anpe_thetaunt_magsigz.toy.%s.samples.npy' % arch)

calculate the percentile score and rank of the true values

In [8]:
pp_thetas, rank_thetas = [], [] 
for igal in np.arange(1000): 
    _mcmc_anpe = anpe_samples[igal,:,:]
    _mcmct_anpe = np.zeros((_mcmc_anpe.shape[0], _mcmc_anpe.shape[1]+1))
    _mcmct_anpe[:,0] = _mcmc_anpe[:,0]

    # transform back to dirichlet space
    _mcmct_anpe[:,1:5] = dirichlet_transform(_mcmc_anpe[:,1:4])
    _mcmct_anpe[:,5:] = _mcmc_anpe[:,4:]
    
    pp_theta, rank_theta = [], []
    for itheta in range(_mcmct_anpe.shape[1]): 
        pp_theta.append(stats.percentileofscore(_mcmct_anpe[:,itheta], x_test[igal,itheta])/100.)
        rank_theta.append(np.sum(np.array(_mcmct_anpe[:,itheta]) < x_test[igal,itheta]))
    pp_thetas.append(pp_theta)
    rank_thetas.append(rank_theta)
    
pp_thetas = np.array(pp_thetas)
rank_thetas = np.array(rank_thetas)

calculate percentile score and rank for test MCMC from arcoiris, for reference

In [9]:
dat_dir = '/scratch/network/chhahn/arcoiris/sedflow/'
x_mcmc_test = np.load(os.path.join(dat_dir, 'test.gold.thetas_sps.toy_noise.npy')) 
x_mcmc_test[:,7:9] = np.log10(x_mcmc_test[:,7:9])

_mags = np.load(os.path.join(dat_dir, 'test.gold.mags.toy_noise.npy'))
_sigs = np.load(os.path.join(dat_dir, 'test.gold.sigs.toy_noise.npy'))
_zred = np.load(os.path.join(dat_dir, 'test.gold.zred.toy_noise.npy')) 
y_mcmc_test = np.concatenate([_mags, _sigs, _zred], axis=1)


pp_thetas_mcmc, rank_thetas_mcmc = [], []
for igal in np.arange(100):
    _mcmc_test = np.load(os.path.join('/scratch/network/chhahn/arcoiris/sedflow/mcmc_test_redo/', 'mcmc.test.toy.gold.%i.npy' % igal))
    _mcmc_test = Train.flatten_chain(_mcmc_test[2000:])
    
    _mcmct_test = np.zeros((_mcmc_test.shape[0], _mcmc_test.shape[1]+1))
    _mcmct_test[:,0] = _mcmc_test[:,0]

    # transform back to dirichlet space
    _mcmct_test[:,1:5] = dirichlet_transform(_mcmc_test[:,1:4])
    _mcmct_test[:,5:] = _mcmc_test[:,4:]
    
    pp_theta_mcmc, rank_theta_mcmc = [], []
    for itheta in range(_mcmct_test.shape[1]):
        pp_theta_mcmc.append(stats.percentileofscore(_mcmct_test[:,itheta], x_mcmc_test[igal,itheta])/100.)
        rank_theta_mcmc.append(np.sum(np.array(_mcmct_test[:,itheta]) < x_mcmc_test[igal,itheta])/float(_mcmct_test.shape[0]))
    pp_thetas_mcmc.append(pp_theta_mcmc)
    rank_thetas_mcmc.append(rank_theta_mcmc)

pp_thetas_mcmc = np.array(pp_thetas_mcmc)
rank_thetas_mcmc = np.array(rank_thetas_mcmc)    

### p-p plot

In [10]:
theta_lbls = [r'$\log M_*$', r"$\beta_1$", r"$\beta_2$", r"$\beta_3$", r"$\beta_4$", r'$f_{\rm burst}$', r'$t_{\rm burst}$', r'$\log \gamma_1$', r'$\log \gamma_2$', r'$\tau_1$', r'$\tau_2$', r'$n_{\rm dust}$']
mpl.use('PDF')
mpl.rcParams['text.usetex'] = True

fig = plt.figure(figsize=(8,8))
sub = fig.add_subplot(111)
for itheta in range(pp_thetas.shape[1]): 
    # evaluate the histogram
    values, base = np.histogram(pp_thetas[:,itheta], bins=40)
    #evaluate the cumulative
    cumulative = np.cumsum(values) / np.sum(values)
    sub.plot(base[:-1], cumulative, label=theta_lbls[itheta])
    
    # evaluate the histogram
    values, base = np.histogram(pp_thetas_mcmc[:,itheta], bins=40)
    #evaluate the cumulative
    cumulative = np.cumsum(values) / np.sum(values)
    sub.plot(base[:-1], cumulative, c='gray', lw=1, ls=':')
sub.plot([], [], c='gray', ls=':', label='MCMC')
sub.plot([0., 1.], [0., 1.], c='k', ls='--')
sub.set_xlim(0., 1.)
sub.set_ylim(0., 1.)
sub.legend(loc='upper left', fontsize=15)
fig.savefig('paper/figs/ppplot.pdf', bbox_inches='tight')

IFrame("paper/figs/ppplot.pdf", width=600, height=600)

### simulation-based calibration
Metric from Talts+(2020) and uses rank statistic rather than percentile score.

In [11]:
fig = plt.figure(figsize=(15,9))
for i in range(pp_thetas.shape[1]): 
    sub = fig.add_subplot(3,4,i+1)
    sub.hist(rank_thetas[:,i]/10000., density=True, histtype='step', linewidth=2)
    sub.hist(rank_thetas_mcmc[:,i], density=True, histtype='step', linewidth=0.75, color='gray', linestyle=':')
    sub.plot([0., 1.], [1., 1.], c='k', ls='--')
    
    sub.text(0.05, 0.95, theta_lbls[i], ha='left', va='top', fontsize=20, transform=sub.transAxes)
    sub.set_xlim(0, 1.)
    sub.set_ylim(0., 2.)
    sub.set_yticklabels([])
    sub.set_xticklabels([])
sub.plot([], [], c='C0', label='ANPE')
sub.plot([], [], c='gray', ls=':', label='MCMC')
sub.legend(loc='lower right', fontsize=15, handletextpad=0)
fig.savefig('paper/figs/sbc.pdf', bbox_inches='tight')

IFrame("paper/figs/sbc.pdf", width=600, height=600)    

# Comparison of ANPE vs MCMC posterior for single galaxy

In [12]:
prior_low = [7, 0., 0., 0., 0., 1e-2, np.log10(4.5e-5), np.log10(4.5e-5), 0, 0., -2.]
prior_high = [12.5, 1., 1., 1., 1., 13.27, np.log10(1.5e-2), np.log10(1.5e-2), 3., 3., 1.]
lower_bounds = torch.tensor(prior_low)
upper_bounds = torch.tensor(prior_high)

prior = Ut.BoxUniform(low=lower_bounds, high=upper_bounds, device='cpu')

In [13]:
fanpe = '/scratch/network/chhahn/sedflow/anpe_thetaunt_magsigz.toy.%s.pt' % arch

anpe = Inference.SNPE(prior=prior, density_estimator=Ut.posterior_nn('maf', hidden_features=500, num_transforms=10), device='cpu')
anpe.append_simulations(
    torch.as_tensor(_x_test.astype(np.float32)), 
    torch.as_tensor(y_test.astype(np.float32)))

p_x_y_estimator = anpe._build_neural_net(torch.as_tensor(_x_test.astype(np.float32)), torch.as_tensor(y_test.astype(np.float32)))
p_x_y_estimator.load_state_dict(torch.load(fanpe))

anpe._x_shape = Ut.x_shape_from_simulation(torch.as_tensor(y_test.astype(np.float32)))

In [14]:
hatp_x_y = anpe.build_posterior(p_x_y_estimator)

# Load NSA data and MCMC posteriors

In [15]:
nsa, _ = Obs.NSA()

flux_nsa = nsa['NMGY'][:,2:]
ivar_nsa = nsa['NMGY_IVAR'][:,2:]

zred_nsa = nsa['Z']

mags_nsa = Train.flux2mag(flux_nsa)
sigs_nsa = Train.sigma_flux2mag(ivar_nsa**-0.5, flux_nsa)

y_nsa = np.concatenate([mags_nsa, sigs_nsa, zred_nsa[:,None]], axis=1)

  return 22.5 - 2.5 * np.log10(flux)


In [16]:
i_nsa = 25

In [17]:
c_light = 2.998e18
jansky_cgs = 1e-23

lambda_sdss = np.array([3543., 4770., 6231., 7625., 9134.])

flux_nsa_conv = flux_nsa[i_nsa] * 1e-9 * 1e17 * c_light / lambda_sdss**2 * (3631. * jansky_cgs) # convert to 10^-17 ergs/s/cm^2/Ang
ivar_nsa_conv = ivar_nsa[i_nsa] * (1e-9 * 1e17 * c_light / lambda_sdss**2 * (3631. * jansky_cgs))**-2.


In [18]:
dir_mcmc = '/scratch/network/chhahn/arcoiris/sedflow/mcmc_nsa_redo/'

_anpe = np.array(hatp_x_y.sample((10000,), x=torch.as_tensor(y_nsa[i_nsa,:]), show_progress_bars=True))
anpe = np.zeros((_anpe.shape[0], _anpe.shape[1]+1))
anpe[:,0] = _anpe[:,0]
# transform back to dirichlet space
anpe[:,1:5] = dirichlet_transform(_anpe[:,1:4])
anpe[:,5:] = _anpe[:,4:]    

_mcmc = np.load(os.path.join(dir_mcmc, 'mcmc.nsa.%i.npy' % i_nsa))
_mcmc = Train.flatten_chain(_mcmc[4000:])
mcmc = np.zeros((_mcmc.shape[0], _mcmc.shape[1]+1))
mcmc[:,0] = _mcmc[:,0]
# transform back to dirichlet space
mcmc[:,1:5] = dirichlet_transform(_mcmc[:,1:4])
mcmc[:,5:] = _mcmc[:,4:]    

Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

In [19]:
from provabgs import models as Models

nsa_filters = Train.photometry_bands()

m_sps = Models.NMF(burst=True, emulator=True)



input parameters : logmstar, beta1_sfh, beta2_sfh, beta3_sfh, beta4_sfh, fburst, tburst, gamma1_zh, gamma2_zh, dust1, dust2, dust_index


In [20]:
log_prob_anpe = hatp_x_y.log_prob(_anpe, x=torch.as_tensor(y_nsa[i_nsa,:]))
i_anpe_bf = np.argmax(log_prob_anpe)

In [21]:
# get best fit
_tt_bf = anpe[i_anpe_bf].copy()
_tt_bf[7:9] = 10**_tt_bf[7:9]
w_anpe, f_anpe = m_sps.sed(_tt_bf, zred_nsa[i_nsa])
maggies_anpe = Train.SED_to_maggies(w_anpe, f_anpe, filters=nsa_filters)
mags_anpe = Train.flux2mag(np.array(list(maggies_anpe.as_array()[0])) * 1e9)

In [22]:
fspecs_mcmc, fluxes_mcmc = [], []
for _tt in mcmc[::100]: 
    tt = _tt.copy()
    tt[7:9] = 10**tt[7:9]
    w_mcmc, f_mcmc = m_sps.sed(tt, zred_nsa[i_nsa])
    _maggies = Train.SED_to_maggies(w_mcmc, f_mcmc, filters=nsa_filters)
    fspecs_mcmc.append(f_mcmc)
    fluxes_mcmc.append(np.array(_maggies.as_array()[0].tolist()) * 1e9)

dflux = flux_nsa[i_nsa] - np.array(fluxes_mcmc)
log_prob_mcmc = -0.5 * np.sum(dflux**2 * ivar_nsa[i_nsa])  

mags_mcmc = Train.flux2mag(fluxes_mcmc[np.argmax(log_prob_mcmc)])
f_mcmc = fspecs_mcmc[np.argmax(log_prob_mcmc)]

In [23]:
print(mags_nsa[i_nsa])
print(mags_anpe)
print(mags_mcmc)

[18.794628 17.801994 17.454891 17.212591 17.092417]
[18.83138339 17.81449052 17.45051507 17.24406462 17.07730637]
[18.81137748 17.81960621 17.43447054 17.23380366 17.08382311]


In [24]:
_low = [7, 0., 0., 0., 0., 0., 1e-2, np.log10(4.5e-5), np.log10(4.5e-5), 0, 0., -2.]
_high = [12.5, 1., 1., 1., 1., 1., 13.27, np.log10(1.5e-2), np.log10(1.5e-2), 3., 3., 1.]
theta_range = [(l, h) for l, h in zip(_low, _high)]

In [30]:
ndim = len(theta_lbls)

fig = plt.figure(figsize=(15, 18))

gs0 = fig.add_gridspec(nrows=ndim, ncols=ndim, top=0.95, bottom=0.26)
for yi in range(ndim):
    for xi in range(ndim):
        sub = fig.add_subplot(gs0[yi, xi])

_fig = DFM.corner(mcmc, color='k', levels=[0.68, 0.95], range=theta_range,
                  plot_density=False, plot_datapoints=False, hist_kwargs={'density': True}, fig=fig)
_ = DFM.corner(anpe, color='C0', levels=[0.68, 0.95], range=theta_range,
               plot_density=False, plot_datapoints=False, hist_kwargs={'density': True}, 
               labels=theta_lbls, label_kwargs={'fontsize': 25}, fig=fig)

# ndim = int(np.sqrt(len(fig.axes)))
axes = np.array(fig.axes).reshape((ndim, ndim))

ax = axes[2, ndim-4]
ax.fill_between([], [], [], color='k', label='MCMC posterior')
ax.fill_between([], [], [], color='C0', label='ANPE posterior')
ax.legend(handletextpad=0.2, markerscale=10, fontsize=25)

axes = np.array(fig.axes).reshape((ndim, ndim))
for yi in range(1, ndim):
    ax = axes[yi, 0]
    ax.set_ylabel(theta_lbls[yi], fontsize=20, labelpad=30)
    ax.yaxis.set_label_coords(-0.6, 0.5)
for xi in range(ndim): 
    ax = axes[-1, xi]
    ax.set_xlabel(theta_lbls[xi], fontsize=20, labelpad=30)
    ax.xaxis.set_label_coords(0.5, -0.55)
for xi in range(ndim): 
    ax = axes[xi,xi]
    ax.set_xlim(theta_range[xi])

gs1 = fig.add_gridspec(nrows=1, ncols=1, top=0.2, bottom=0.05)
sub = fig.add_subplot(gs1[0,0])

sub.plot(w_anpe, f_anpe, c='C0', lw=1, label='ANPE best-fit')
sub.plot(w_mcmc, f_mcmc, c='k', lw=1, ls=':', label='MCMC best-fit')
sub.errorbar(lambda_sdss, flux_nsa_conv, yerr=ivar_nsa_conv**-0.5, fmt='.C3', markersize=6, elinewidth=2, label='NSA Photometry')
sub.legend(loc='upper right', fontsize=18, handletextpad=0)
sub.set_xlabel('wavelength [$A$]', fontsize=20) 
sub.set_xlim(3e3, 1e4)
sub.set_ylabel('flux [$10^{-17} erg/s/cm^2/A$]', fontsize=20, labelpad=15) 
sub.set_ylim(0., 60.)

fig.savefig('paper/figs/corner.pdf', bbox_inches='tight')
IFrame("paper/figs/corner.pdf", width=600, height=600)