In [1]:
from itertools import product
from operator import itemgetter

import pyro
import torch
from matplotlib import pyplot as plt

# noinspection PyUnresolvedReferences
from clipppy.patches import torch_numpy
from libsimplesn import SimpleSN

### NRE

In [None]:
# Simulated Pantheon-G10 spec-z survey with N = 100 000 objects, and the NRE built from its config
# (generative and latent parts of the model enabled).
simplesn = SimpleSN(
    survey='pantheon-g10',
    datatype='specz',
    N=100_000,
    suffix=0,
    version=0,
)
config = simplesn.config('simplesn.yaml', gen=True, latent=True)
nre = config.lightning_nre

# Names of the latent model parameters (an explicit 1-tuple).
LATENT_PARAMS = ('M0',)

In [None]:
from libplotting import get_priors

# Override the dataset's parameter ranges with the stored HDI (lower, upper) bounds
# for this survey's (datatype, N) combination.
lower_upper = itemgetter('lower', 'upper')
hdi_table = simplesn.hdi_bounds[simplesn.datatype, simplesn.N]
nre.dataset_config.kwargs['ranges'].update({
    param: lower_upper(bounds)
    for param, bounds in hdi_table.to_dict().items()
})

# Priors of all non-latent parameters, summarised as the (lower, upper) bounds of their support.
priors = get_priors(set(nre.param_names).difference(LATENT_PARAMS), nre.dataset.dataset)
ranges = {
    param: (prior.support.lower_bound, prior.support.upper_bound)
    for param, prior in priors.items()
}
ranges

In [None]:
from clipppy.stochastic import find_sampler
from clipppy.utils.messengers import CollectSitesMessenger

# Prior check of the latent parameters: draw a batch of samples from each latent parameter's
# sampler in the NRE model, then compare the M0 samples with a reference Gaussian.
N_PRIOR_SAMPLES = 10_000
PRIOR_SEED = 0
# NOTE(review): presumably the M0 prior configured in simplesn.yaml — confirm.
M0_REF_LOC, M0_REF_SCALE = -19.5, 0.1

pyro.set_rng_seed(PRIOR_SEED)  # make the sampled histogram reproducible on Restart & Run All
with CollectSitesMessenger(*LATENT_PARAMS) as collected_sites, pyro.plate('plate', N_PRIOR_SAMPLES), nre.dataset.dataset.context:
    # Sample every latent site (not only M0) so that the extraction below finds all of LATENT_PARAMS.
    for param in LATENT_PARAMS:
        find_sampler(config._model, param)()
prior_samples = {key: collected_sites[key]['value'] for key in LATENT_PARAMS}

# Use a named grid rather than `_ := ...`: binding `_` shadows IPython's last-output variable.
m0_grid = torch.linspace(-20, -19, 101)
fig, ax = plt.subplots()
ax.hist(prior_samples['M0'][:, 0].numpy(), bins=100, density=True, label='prior samples')
ax.plot(m0_grid, torch.distributions.Normal(M0_REF_LOC, M0_REF_SCALE).log_prob(m0_grid).exp_(),
        label=rf'$\mathcal{{N}}({M0_REF_LOC}, {M0_REF_SCALE})$')
ax.set_xlabel(r'$M_0$')
ax.set_ylabel('density')
ax.legend();

### MCMC

In [3]:
# Plot the emcee latent residuals for every simulated survey configuration.
# The per-survey work lives in a function so the loop does not overwrite the
# `simplesn` / `config` globals that the NRE cells above depend on.
EMCEE_NS = (1000, 2000, 5000, 10_000, 20_000, 50_000, 100_000)
EMCEE_DATATYPES = ('mphotoz', 'specz')
EMCEE_SUFFIXES = (0,)
EMCEE_VERSIONS = range(10)


def plot_emcee_latent(sn):
    """Plot and save the emcee latent residuals of one ``SimpleSN`` survey.

    The left panel shows the residuals of the two columns of the ``sn.emcee_latent``
    location w.r.t. ``truth = distmod + mean_M0`` (x axis, labelled $m$), each divided by
    ``data['R_z']`` (1 when absent). Error bars are the square roots of the matching
    diagonal entries of the covariance. The right panel shows step histograms of the same
    residuals, plus a standard-normal density curve for non-spec-z data.

    :param sn: survey whose ``data``, ``emcee_latent``, ``plotdir`` and ``data_prefix`` are used.
    :returns: the path the figure was saved to.
    """
    # NOTE(review): the resulting config is never used here; the call is kept only in case
    # constructing it has side effects that `sn.data` / `sn.emcee_latent` rely on.
    sn.config('simplesn-marginal.yaml', gen=False)
    data = sn.data
    loc, var = sn.emcee_latent

    isspecz = sn.datatype == 'specz'

    # Use the stored distance modulus if present, otherwise reconstruct it as m - M.
    distmod = data['distmod'] if 'distmod' in data else data['m'] - data['M']
    R_z = data.get('R_z', 1)
    truth = distmod + data['mean_M0']

    y0 = (loc[:, 0] - truth) / R_z
    y1 = (loc[:, 1] - truth) / R_z

    # Spec-z: automatic 50 bins; otherwise fixed bins on [-5, 5], reused for the N(0, 1) overlay.
    bins = 50 if isspecz else torch.linspace(-5, 5, 51)
    histkwargs = dict(bins=bins, density=True, histtype='step', orientation='horizontal')
    labelfmt = r'$\bar{%s}^s - M_0$' if isspecz else r'$(\bar{%s}^s - M_0) / R_z$'
    path = sn.plotdir / f'{sn.data_prefix}-emcee-latent.png'

    fig, axs = plt.subplots(1, 2, sharey='row', gridspec_kw=dict(width_ratios=(2, 1)), figsize=(8, 4))
    try:
        axs[0].errorbar(truth, y0, yerr=var[:, 0, 0]**0.5, ls='none', color='green')
        axs[0].errorbar(truth, y1, yerr=var[:, 1, 1]**0.5, ls='none', color='red')
        axs[0].set_xlabel(r'$m$')

        axs[1].hist(y0.numpy(), **histkwargs, color='green', label=labelfmt % 'M')
        axs[1].hist(y1.numpy(), **histkwargs, color='red', label=labelfmt % 'm')
        if not isspecz:
            # Raw string: '\m' is an invalid escape in a plain literal (SyntaxWarning on 3.12+).
            axs[1].plot(torch.distributions.Normal(0, 1).log_prob(bins).exp_(), bins,
                        label=r'$\mathcal{G}(0, 1)$')
        axs[1].legend()
        axs[1].set_xticklabels([])

        fig.suptitle(sn.data_prefix)

        print(path)
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting/saving fails, so the loop cannot leak figures.
        plt.close(fig)
    return path


for N, datatype, suffix, version in product(EMCEE_NS, EMCEE_DATATYPES, EMCEE_SUFFIXES, EMCEE_VERSIONS):
    plot_emcee_latent(SimpleSN(survey='pantheon-g10', datatype=datatype, N=N, suffix=suffix, version=version))

res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-0-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-1-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-2-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-3-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-4-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-5-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-6-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-7-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-8-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-mphotoz-9-emcee-latent.png
res/pantheon-g10/pantheon-g10-1000-0/plots/pantheon-g10-1000-0-specz-0-emcee-latent.png
res/pantheon