In [None]:
from configparser import ConfigParser
import copy
import sys
import glob
import os
import json
import re
import pandas as pd
import itertools

import bilby
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import norm
from scipy.special import logsumexp
import seaborn as sns

# Seed NumPy's global RNG so NumPy-based randomness later in the notebook is
# reproducible (presumably including any subsampling inside calculate_js — confirm).
np.random.seed(1234)

# Repository root relative to this notebook; appended to sys.path so that the
# shared helper module(s) below (e.g. utils) can be imported. The append must
# stay before those imports.
basedir = "../../"

sys.path.append(basedir)

# Project-local helpers: plotting configuration and natural (human) sort order.
from utils import configure_plotting, natural_sort
configure_plotting(basedir)

# NOTE(review): not referenced elsewhere in the visible notebook — presumably a
# LaTeX text width (inches) for sizing figures; confirm or remove.
linewidth = 6.17804


# Disable the bilby plotting style
# (set after configure_plotting so this value is the one left in the environment)
os.environ["BILBY_STYLE"] = "none"

# Jensen–Shannon divergence helper used in the JS results section.
from js import calculate_js

In [None]:
# config.mk holds bare "key = value" assignments (presumably shared with a
# Makefile) and has no section header, so a dummy [top] section is prepended
# to make it readable by ConfigParser. Option names are lower-cased by
# ConfigParser's default optionxform.
parser = ConfigParser()
with open("config.mk") as stream:
    config_text = stream.read()
parser.read_string("\n".join(["[top]", config_text]))
config = parser["top"]

In [None]:
# Output directory of each sampler, read from config.mk as "<sampler>_outdir".
# Joining with "" guarantees a trailing path separator.
result_paths = {
    sampler_name: os.path.join(config[f"{sampler_name}_outdir"], "")
    for sampler_name in ("dynesty", "nessai", "inessai")
}
print(f"Using the following result paths: {result_paths}")

In [None]:
# Load every per-run result file ("*par*.hdf5" under <outdir>/result) for each
# sampler, in natural-sort order.
results = {}
for s, rp in result_paths.items():
    pattern = os.path.join(rp, "result", "*par*.hdf5")
    files = natural_sort(glob.glob(pattern))
    # Fail early with a clear message: an empty list would otherwise surface
    # later as an IndexError or as NaN means in the summary tables.
    if not files:
        raise FileNotFoundError(
            f"No result files matching {pattern!r} for sampler {s!r}"
        )
    results[s] = [bilby.core.result.read_in_result(rf) for rf in files]
results_list = list(results.values())
samplers = list(results.keys())

In [None]:
# Load the merged result of each sampler ("*merge*.hdf5" under <outdir>/result).
merge_results = {}
for s, rp in result_paths.items():
    pattern = os.path.join(rp, "result", "*merge*.hdf5")
    merge_files = natural_sort(glob.glob(pattern))
    if not merge_files:
        raise FileNotFoundError(
            f"No merged result file matching {pattern!r} for sampler {s!r}"
        )
    # If several files match, the first one in natural-sort order is used.
    merge_results[s] = bilby.core.result.read_in_result(merge_files[0])
merge_results_list = list(merge_results.values())
# Derive the sampler order from merge_results, which the JS comparison indexes
# into, so this cell does not depend on the per-run results cell.
samplers = list(merge_results.keys())

# Table results

In [None]:
# Display name of each sampler in the LaTeX tables; \codestyle is presumably a
# macro defined in the paper's preamble (not defined here).
sampler_labels = dict(
    dynesty=r"\codestyle{dynesty}",
    nessai=r"\codestyle{nessai}",
    inessai=r"\codestyle{i-nessai}",
)

In [None]:
# Column headers of the summary tables, keyed by metric name.
field_labels = dict(
    wall_time="Wall time [min]",
    n_likelihood_evaluations="Likelihood evaluations",
    ess="Effective sample size",
)

In [None]:
def kish_ess(weights):
    """Kish effective sample size, (sum w)^2 / sum w^2, evaluated in log space."""
    log_weights = np.log(weights)
    return np.exp(2 * logsumexp(log_weights) - logsumexp(2 * log_weights))


# Per-run metrics for every sampler: wall time in minutes, number of likelihood
# evaluations, and the effective sample size of the nested samples.
table_results = {}
for sampler, sampler_results in results.items():
    table_results[sampler] = dict(
        wall_time=[run.sampling_time.total_seconds() / 60 for run in sampler_results],
        n_likelihood_evaluations=[run.num_likelihood_evaluations for run in sampler_results],
        ess=[kish_ess(run.nested_samples['weights']) for run in sampler_results],
    )
        

In [None]:
# Display the raw per-run metrics (wall time, likelihood evaluations, ESS) per sampler.
table_results

In [None]:
# Single-value table: each cell is the mean over all runs of that sampler,
# formatted as LaTeX math (\num is presumably the siunitx macro).
summary = {}
for sampler, metrics in table_results.items():
    summary[sampler_labels.get(sampler)] = {
        field_labels.get("wall_time"): f"${np.mean(metrics['wall_time']):.1f}$",
        field_labels.get("n_likelihood_evaluations"):
            rf"$\num{{{np.mean(metrics['n_likelihood_evaluations']):.2e}}}$",
        field_labels.get("ess"): f"${np.mean(metrics['ess']):.0f}$",
    }

In [None]:
# Cost of dynesty and nessai relative to i-nessai: ratio of mean likelihood
# evaluations and of mean wall time (values > 1 mean i-nessai is cheaper).
inessai_metrics = table_results["inessai"]
inessai_mean_evals = np.mean(inessai_metrics["n_likelihood_evaluations"])
inessai_mean_time = np.mean(inessai_metrics["wall_time"])
for sampler in ["dynesty", "nessai"]:
    sampler_metrics = table_results[sampler]
    print(sampler)
    print("likelihood ratio", np.mean(sampler_metrics["n_likelihood_evaluations"]) / inessai_mean_evals)
    print("time ratio ", np.mean(sampler_metrics["wall_time"]) / inessai_mean_time)

In [None]:
# One row per sampler, one column per metric.
df = pd.DataFrame(summary).transpose()
print(df)

In [None]:
# Create the output directory if needed so the cell also works on a fresh checkout.
os.makedirs("results", exist_ok=True)
# "lccc": left-aligned sampler column followed by the three metric columns.
with open("results/bns_comparison_table_one_run.tex", "w") as fp:
    fp.write(df.style.to_latex(hrules=True, column_format="lccc"))

In [None]:
# Multi-run table: each cell reports mean \pm standard deviation over the runs.
# np.std uses its default ddof=0 (population standard deviation).
# Raw f-strings keep "\pm" as literal LaTeX; in a plain f-string "\p" is an
# invalid escape sequence that Python warns about (SyntaxWarning in 3.12+).
summary = {}
for sampler, tr in table_results.items():
    sk = sampler_labels.get(sampler)
    summary[sk] = {
        field_labels.get("wall_time"):
            rf"${np.mean(tr['wall_time']):.1f} \pm {np.std(tr['wall_time']):.1f}$",
        field_labels.get("n_likelihood_evaluations"):
            r"$\num{" + f"{np.mean(tr['n_likelihood_evaluations']):.2e}"
            + r"}\pm\num{" + f"{np.std(tr['n_likelihood_evaluations']):.2e}" + r"}$",
        field_labels.get("ess"):
            rf"${np.mean(tr['ess']):.0f} \pm {np.std(tr['ess']):.0f}$",
    }

In [None]:
# One row per sampler, one column per metric (mean \pm std strings).
df = pd.DataFrame(summary).transpose()
print(df)

In [None]:
# Create the output directory if needed so the cell also works on a fresh checkout.
os.makedirs("results", exist_ok=True)
# "lccc": left-aligned sampler column followed by the three metric columns.
with open("results/bns_comparison_table.tex", "w") as fp:
    fp.write(df.style.to_latex(hrules=True, column_format="lccc"))

# Corner plot

In [None]:
# LaTeX axis labels keyed by bilby parameter name; dimensional quantities carry
# their unit in square brackets.
# NOTE(review): luminosity_distance has no unit here, unlike chirp_mass and
# geocent_time — confirm this is intended.
cbc_param_labels = {
    "a_1": r"$\chi_1$",
    "a_2": r"$\chi_2$",
    "chirp_mass": r"$\mathcal{M}\;[\textrm{M}_\odot]$",
    "dec": r"$\delta$",
    "ra": r"$\alpha$",
    "geocent_time": r"$t_\textrm{c}\;[\textrm{s}]$",
    "luminosity_distance": r"$d_\textrm{L}$",
    "mass_ratio": r"$q$",
    "tilt_1": r"$\theta_1$",
    "tilt_2": r"$\theta_2$",
    "phi_12": r"$\phi_{12}$",
    "phi_jl": r"$\phi_{JL}$",
    "psi": r"$\psi$",
    "theta_jn": r"$\theta_{JN}$",
    "chi_1": r"$\chi_1$",
    "chi_2": r"$\chi_2$",
}

# Same labels without units (used as row labels of the JS table). Only the two
# entries that carry units above differ; key order is unchanged.
cbc_param_labels_wo_units = {
    **cbc_param_labels,
    "chirp_mass": r"$\mathcal{M}$",
    "geocent_time": r"$t_\textrm{c}$",
}


In [None]:
# Sampled parameters (taken from the first run of the first sampler) and
# their corner-plot labels; a parameter with no entry in cbc_param_labels
# maps to None, so check the printed list.
parameters = results_list[0][0].search_parameter_keys
print(parameters)
corner_labels = list(map(cbc_param_labels.get, parameters))
print(corner_labels)

In [None]:
# Overlay the merged posteriors of all samplers on one corner plot. Colours
# follow the order of merge_results_list (dynesty, nessai, inessai). Tick, line
# and font sizes are enlarged through a temporary seaborn plotting context.
with sns.plotting_context(
    rc={
        "xtick.labelsize": 24,
        "ytick.labelsize": 24,
        "xtick.major.size": 6,
        "xtick.major.width": 1.0,
        "xtick.minor.size": 3.0,
        "xtick.minor.width": 1.0,
        "ytick.major.size": 6,
        "ytick.major.width": 1.0,
        "ytick.minor.size": 3,
        "ytick.minor.width": 1.0,
        "lines.linewidth": 2.0,
        "patch.linewidth": 2.0,
    }
):

    fig = bilby.core.result.plot_multiple(
        merge_results_list,
        parameters=parameters,
        bins=50,
        colours=["C2", "C1", "C0"],
        titles=False,
        fill_contours=False,
        smooth=0.95,
        label_kwargs=dict(fontsize=32),
        plot_datapoints=False,
        corner_labels=corner_labels,
        labelpad=0.12,
    )
    # Strip any legend from the figure's axes. get_legend() returns None when an
    # axis has no legend; check explicitly instead of catching AttributeError,
    # which could also hide unrelated errors.
    axs = fig.get_axes()
    for a in axs:
        legend = a.get_legend()
        if legend is not None:
            legend.remove()

In [None]:
# Create the output directory if needed so the cell also works on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/bns_corner_plot.pdf")

In [None]:
# Create the output directory if needed so the cell also works on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/bns_corner_plot.png", transparent=True)

# Jensen–Shannon divergence results

In [None]:
# Every unordered pair of samplers is compared, parameter by parameter.
combinations = list(itertools.combinations(samplers, 2))
names = ["-".join(pair) for pair in combinations]
# The dynesty merged result defines which parameters are compared.
reference_result = merge_results["dynesty"]
parameters = reference_result.search_parameter_keys
labels = reference_result.parameter_labels_with_unit
# Number of samples passed to calculate_js; the threshold is only printed here.
n_samples = 5_000
threshold = 10 / n_samples
print(f"Threshold: {threshold}")
# Convert to millinats
conversion_factor = 1000

In [None]:
# Display the sampler pairs that will be compared.
combinations

In [None]:
# Jensen–Shannon divergence (in millinats) between the merged posteriors of
# every sampler pair, one table column per pair and one row per parameter.
js_results = {}
for comb in combinations:
    print(comb)
    # Column header such as "\dynesty-\inessai" (presumably LaTeX macros
    # defined in the paper).
    name = rf'\{comb[0].replace("-", "")}-\{comb[1].replace("-", "")}'
    js = {}
    for p in parameters:
        # Named js_summary (not summary) so the table `summary` dict built
        # earlier is not clobbered if those cells are re-run afterwards.
        js_summary = calculate_js(
            merge_results[comb[0]].posterior[p],
            merge_results[comb[1]].posterior[p],
            nsamples=n_samples,
            base=np.e,  # nats
        )
        # Convert to desired units
        js_summary.median *= conversion_factor
        js_summary.plus *= conversion_factor
        js_summary.minus *= conversion_factor

        # Fall back to the LaTeX-escaped parameter name so parameters without a
        # label get their own row instead of all colliding on a None key.
        label = cbc_param_labels_wo_units.get(p, p.replace("_", r"\_"))

        # NOTE(review): the upper error has no explicit "+" sign while the lower
        # one has "-"; confirm this is the intended notation.
        js[label] = (
            f'${js_summary.median:.2f}' + '^{' + f'{js_summary.plus:.2f}'
            + '}_{-' + f'{js_summary.minus:.2f}' + '}$'
        )
    js_results[name] = js

Convert the results to a DataFrame so that they are formatted as a table.

In [None]:
# Columns are sampler pairs, rows are parameter labels.
df = pd.DataFrame.from_dict(js_results)
df

Write the DataFrame to a LaTeX table.

In [None]:
# Create the output directory if needed so the cell also works on a fresh checkout.
os.makedirs("results", exist_ok=True)
with open("results/js_table.tex", "w") as fp:
    fp.write(df.style.to_latex(hrules=True))

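Replace the booktabs horizontal rules (`\toprule`, `\midrule`, `\bottomrule`) in every generated table with their IOP equivalents (`\br`, `\mr`) to match the IOP guidelines.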

In [None]:
def convert_to_iop_rules(text):
    r"""Swap booktabs rules in LaTeX ``text`` for their IOP equivalents.

    ``\toprule`` and ``\bottomrule`` become ``\br``; ``\midrule`` becomes ``\mr``.
    Matching on the leading backslash leaves other commands such as
    ``\cmidrule`` untouched (a bare "midrule" replacement would turn it into
    ``\cmr``).
    """
    for old, new in (
        (r"\toprule", r"\br"),
        (r"\midrule", r"\mr"),
        (r"\bottomrule", r"\br"),
    ):
        text = text.replace(old, new)
    return text


# Rewrite every generated table in results/ in place.
for tex_file in glob.glob("results/*.tex"):
    print(tex_file)
    with open(tex_file, "r") as f:
        original_text = f.read()
    with open(tex_file, "w") as f:
        f.write(convert_to_iop_rules(original_text))