In [1]:
from collections import OrderedDict
from functools import partial
from itertools import product

import numpy as np
import pandas as pd
import xarray as xa
from matplotlib import pyplot as plt
from scipy.stats import gaussian_kde
from tqdm.auto import tqdm

from libplotting import to_multidict
from libsimplesn import SimpleSN


dims = OrderedDict((
    ('N', (1_000, 2_000, 5_000, 10_000, 20_000, 50_000, 100_000)),
    ('datatype', (
        # 'specz', 'mphotoz',
        'photoz',
    )),
    ('suffix', list(range(1))),
    ('version', list(range(10)))
))
args = list(product(*dims.values()))


def dewarmup(N, datatype):
    return {'draw': slice(500 if datatype == 'photoz' else 200, None)}

SURVEY = 'pantheon-g10'
ssn = SimpleSN(SURVEY)

In [2]:
from multiprocessing import Pool


def hdi2d(data: xa.Dataset, names: tuple[str, str]):
    x1, x2 = (data[n].to_numpy().flatten() for n in names)
    asort = np.argsort(gaussian_kde((x1, x2))((x1, x2)))
    x1, x2 = x1[asort], x2[asort]
    return x1, x2, np.linspace(1, 0, len(x1)+1)[1:]

def do(N, datatype, suffix, version):
    return (N, datatype, suffix, version), hdi2d(
        SimpleSN(survey=SURVEY, N=N, suffix=suffix, datatype=datatype, version=version).emcee_result.to_dataset()[dewarmup(N, datatype)],
        ('Om0', 'Ode0')
    )

def starmapper(func, args):
    return func(*args)


ssn.hdi = dict(tqdm(Pool().imap_unordered(partial(starmapper, do), args), total=len(args)))

  0%|          | 0/70 [00:00<?, ?it/s]

In [2]:
hdis = to_multidict(pd.Series(ssn.hdi))

In [7]:
NSIGMA = 5.

ssn.hdi_bounds = pd.DataFrame(_ := {
    (datatype, N, key): dict(zip(('lower', 'upper'), val.to_numpy().tolist()))
    for datatype in dims['datatype']
    for val in [
        xa.concat((
            xa.concat((
                xa.concat((
                    SimpleSN(survey=SURVEY, N=N, suffix=suffix, datatype=datatype, version=version).emcee_result.to_dataset()
                    for version in dims['version']
                ), dim=pd.Index(dims['version'], name='version'))
                for suffix in dims['suffix']
            ), dim=pd.Index(dims['suffix'], name='suffix'))
            for N in dims['N']
        ), dim=pd.Index(dims['N'], name='N'))[dewarmup(N, datatype)]
    ]
    for m in [val.mean(('chain', 'draw'))]
    for s in [val.std(('chain', 'draw'))]
    for bds in [xa.concat(((m-NSIGMA*s).min(('suffix', 'version')), (m+NSIGMA*s).max(('suffix', 'version'))), pd.Index(['lower', 'upper'], name='bound'))]
    for N in dims['N']
    for key, val in bds.loc[{'N': N, 'bound': ['lower', 'upper']}].items()
}, columns=pd.MultiIndex.from_tuples(_.keys(), names=('datatype', 'N', 'param')))

In [None]:
# bounds = {
#     'mphotoz': {
#         1_000: ((0, 0.7), (0, 1.5)),
#         2_000: ((0, 0.7), (0, 1.5)),
#         5_000: ((0., 0.6), (0., 1.2)),
#         10_000: ((0.1, 0.5), (0.3, 1.1)),
#         20_000: ((0.15, 0.5), (0.4, 1.1)),
#         50_000: ((0.2, 0.4), (0.5, 0.9)),
#         100_000: ((0.25, 0.35), (0.6, 0.8)),
#     }
# }

In [23]:
from matplotlib.collections import LineCollection

bounds = ssn.hdi_bounds

for N, _bds in ((N, bounds['mphotoz', N]) for N in bounds['mphotoz'].columns.get_level_values('N')):
    plt.figure()
    plt.gca().add_collection(LineCollection((((0, 0.1), (1, 2.1)),), color='k'), autolim=False)

    for hdi in hdis[N]['mphotoz'].values():
        for hdi in hdi.values():
            plt.tricontour(*hdi, levels=(0.9,), colors='g', alpha=0.7)

    for hdi in hdis[N]['specz'].values():
        for hdi in hdi.values():
            plt.tricontour(*hdi, levels=(0.9,), colors='r', alpha=0.7)

    plt.plot(0.3, 0.7, 'ko')
    plt.gca().set_aspect('equal')
    plt.xlim(_bds['Om0'])
    plt.ylim(_bds['Ode0'])
    plt.xlabel(r'$\Omega_{m, 0}$'); plt.ylabel(r'$\Omega_{\Lambda, 0}$')

    plt.legend([plt.Line2D([], [], color='g'), plt.Line2D([], [], color='r')], ['photoz', 'specz'], loc='upper left')
    plt.text(0.98, 0.02, f'$N={N}$', va='bottom', ha='right', transform=plt.gca().transAxes)
    plt.savefig(ssn.hdidir / f'pantheon-g10-{N}.png', bbox_inches='tight')
    plt.close()

  for N, _bds in ((N, bounds['mphotoz', N]) for N in bounds['mphotoz'].columns.get_level_values('N')):
