In [None]:
import glob
import os
import sys
import datetime

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Path to the directory holding the shared ``utils`` module (two levels up
# from this notebook); presumably the repository root — TODO confirm.
basedir = "../../"
# A relative ``sys.path`` entry is resolved against the current working
# directory, so this assumes the notebook is run from its own directory.
sys.path.append(basedir)
# Must come after the ``sys.path`` change above so that ``utils`` is importable.
from utils import configure_plotting, load_json

# Apply the project-wide plotting configuration (defined in ``utils``),
# passing it the same base directory.
configure_plotting(basedir)

In [None]:
# Pool sizes (number of cores) that were benchmarked: powers of two, 1 to 16.
n_pool = [2 ** exponent for exponent in range(5)]

In [None]:
def _load_run_results(pattern):
    """Load every result file matching the glob ``pattern``.

    Prints the pattern and the number of matches so that missing runs are
    visible in the cell output. Matches are sorted because ``glob`` returns
    them in an arbitrary, filesystem-dependent order.

    Returns a list with one loaded result (via ``load_json``) per file.
    """
    print(pattern)
    files = sorted(glob.glob(pattern))
    print(f"Found {len(files)} files")
    return [load_json(f) for f in files]


# Raw per-run results, keyed by pool size: i-nessai runs and nessai baseline runs.
_results = {}
_baseline_results = {}
for p in n_pool:
    _results[p] = _load_run_results(
        os.path.join(f"n_pool_{p}", "result", "*_nessai/result.json")
    )
    _baseline_results[p] = _load_run_results(
        os.path.join(f"n_pool_{p}_baseline", "result", "*_nessai/result.json")
    )
print("Done")

In [None]:
def get_time(string):
    """Convert a duration string to a number of seconds.

    Accepts the format produced by ``str(datetime.timedelta)``:
    ``"H:MM:SS.ffffff"``, ``"H:MM:SS"`` (no fractional part) and
    ``"N day(s), H:MM:SS[.ffffff]"`` for durations of a day or more.
    The previous ``strptime("%H:%M:%S.%f")`` parser rejected the last two.

    Parameters
    ----------
    string : str
        Duration string.

    Returns
    -------
    float
        Duration in seconds.

    Raises
    ------
    TypeError
        If ``string`` is not a ``str``.
    ValueError
        If ``string`` is not a recognised duration.
    """
    if not isinstance(string, str):
        raise TypeError(f"Expected a duration string, got {type(string).__name__}")
    days = 0
    clock = string.strip()
    if "day" in clock:
        # e.g. "2 days, 3:04:05.5" -> days=2, clock="3:04:05.5"
        day_part, clock = clock.split(",", 1)
        days = int(day_part.split()[0])
    hours, minutes, seconds = clock.strip().split(":")
    return datetime.timedelta(
        days=days, hours=int(hours), minutes=int(minutes), seconds=float(seconds)
    ).total_seconds()

In [None]:
def _to_seconds(value):
    """Convert one timing entry from a result file to a float (seconds).

    Strings are parsed with ``get_time``. ``None`` becomes NaN so that the
    NaN-aware reductions below ignore it. Anything else is cast to ``float``.
    """
    if value is None:
        return np.nan
    if isinstance(value, str):
        return get_time(value)
    return float(value)


def _summary_table(runs_by_pool, key):
    """Mean and standard deviation, in minutes, of ``key`` for each pool size.

    Parameters
    ----------
    runs_by_pool : dict
        Maps pool size -> list of loaded result dicts. Rows follow the dict's
        insertion order.
    key : str
        Name of the timing entry to summarise.

    Returns
    -------
    pandas.DataFrame
        Float columns ``mean`` and ``std`` with one row per pool size and a
        default integer index.
    """
    rows = []
    for runs in runs_by_pool.values():
        seconds = np.array([_to_seconds(run[key]) for run in runs], dtype=float)
        rows.append(
            {"mean": np.nanmean(seconds) / 60, "std": np.nanstd(seconds) / 60}
        )
    # Build the frame once (DataFrame.append was removed in pandas 2.0 and
    # growing a frame row by row is quadratic).
    return pd.DataFrame(rows, columns=["mean", "std"])


# i-nessai timings, one summary table per timing key.
results = {
    key: _summary_table(_results, key)
    for key in (
        "sampling_time",
        "add_samples_time",
        "update_level_time",
        "update_ns_time",
        "redraw_time",
        "likelihood_evaluation_time",
    )
}
# nessai (baseline) timings, one summary table per timing key.
baseline_results = {
    key: _summary_table(_baseline_results, key)
    for key in (
        "sampling_time",
        "likelihood_evaluation_time",
        "training_time",
        "population_time",
    )
}

In [None]:
# Subtract the likelihood-evaluation time from the two timings this analysis
# treats as including it (nessai's ``population_time`` and i-nessai's
# ``add_samples_time``), leaving only the overhead of those steps.
#
# Only the means can be subtracted meaningfully. The std of a difference of
# two correlated timings cannot be recovered from the individual stds, so the
# ``std`` column is set to NaN rather than to the difference of the stds
# (which is not an uncertainty and can be negative).
#
# NOTE: this overwrites the entries, so re-running this cell on its own
# subtracts the likelihood time twice; re-run the aggregation cell first.
def _subtract_likelihood_mean(timing, likelihood):
    """Return a copy of ``timing`` with ``likelihood["mean"]`` subtracted from its
    ``mean`` column and its ``std`` column set to NaN."""
    return timing.assign(mean=timing["mean"] - likelihood["mean"], std=np.nan)


baseline_results["population_time"] = _subtract_likelihood_mean(
    baseline_results["population_time"],
    baseline_results["likelihood_evaluation_time"],
)
results["add_samples_time"] = _subtract_likelihood_mean(
    results["add_samples_time"], results["likelihood_evaluation_time"]
)

In [None]:
# i-nessai: the total wall time is the sampling time plus the redraw time.
likelihood_fraction = results["likelihood_evaluation_time"]["mean"].div(
    results["sampling_time"]["mean"].add(results["redraw_time"]["mean"])
)
print(f"Fraction of time spent evaluating the likelihood: {likelihood_fraction}")
# Complementary fraction: time spent on everything except the likelihood.
print(likelihood_fraction.rsub(1))

In [None]:
# nessai (baseline): the total wall time is just the sampling time.
likelihood_fraction = baseline_results["likelihood_evaluation_time"]["mean"].div(
    baseline_results["sampling_time"]["mean"]
)
print(f"Fraction of time spent evaluating the likelihood: {likelihood_fraction}")
# Complementary fraction: time spent on everything except the likelihood.
print(likelihood_fraction.rsub(1))

In [None]:
# Legend label for each timing key. i-nessai's "update_level_time" and
# nessai's "training_time" are both shown as "Training".
bar_labels = {
    "sampling_time": "Total",
    "add_samples_time": "Adding samples",
    "redraw_time": "Resampling",
    "update_level_time": "Training",
    "update_ns_time": "Meta-proposal",
    "likelihood_evaluation_time": "Likelihood",
    "training_time": "Training",
}

In [None]:
# Stacked bar chart of the wall time of nessai (left bar of each pair) and
# i-nessai (right bar) against the number of cores. Each bar is split, from
# bottom to top, into "Other", training and likelihood evaluation. A marker
# with error bars shows the total mean +/- std across runs.


def _as_float(values):
    """Return ``values`` (e.g. a pandas Series) as a float numpy array.

    The summary tables may have ``object`` dtype depending on how they were
    built, so cast explicitly before doing arithmetic or plotting.
    """
    return np.asarray(values, dtype=float)


def _plot_timing_stack(ax, x, total_mean, total_std, segments, *, colour, marker, width):
    """Draw stacked bars at ``x`` plus a marker with error bars for the total.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    x : array-like
        Bar centres.
    total_mean, total_std : array-like
        Total wall time and its standard deviation, drawn as the marker.
    segments : list of (label, height, bar_kwargs)
        Bar segments drawn bottom-up; each segment is stacked on the previous
        ones. ``bar_kwargs`` is passed to ``ax.bar``.
    colour : colour
        Colour of the total marker.
    marker : str
        Matplotlib marker for the total.
    width : float
        Bar width.

    Returns
    -------
    handles, labels : list, list
        Legend entries in reading order: the total first, then the segments
        from top to bottom (matching how they are stacked in the plot).
    """
    total_handle = ax.errorbar(
        x,
        _as_float(total_mean),
        yerr=_as_float(total_std),
        ls="",
        marker=marker,
        color=colour,
        label="Total",
    )
    # Float accumulator: an integer array here would force an unsafe
    # float -> int cast when the heights are added.
    bottom = np.zeros(len(x))
    segment_handles, segment_labels = [], []
    for label, height, bar_kwargs in segments:
        height = _as_float(height)
        segment_handles.append(
            ax.bar(
                x,
                height,
                bottom=bottom,
                label=label,
                linewidth=0.0,
                width=width,
                **bar_kwargs,
            )
        )
        segment_labels.append(label)
        bottom = bottom + height
    return [total_handle] + segment_handles[::-1], ["Total"] + segment_labels[::-1]


fig, ax = plt.subplots(dpi=200)
xloc = np.arange(len(n_pool))
width = 0.4
sep = width / 2
order = ["update_level_time", "likelihood_evaluation_time"]
base_order = ["training_time", "likelihood_evaluation_time"]
colours = sns.color_palette("RdYlBu", n_colors=7)
baseline_colours = [colours[-1], colours[-2], colours[-3]]
ins_colours = colours
# Hatches for the i-nessai segments, bottom to top: Other, Training, Likelihood.
hatch = ["//", r"\\", "||"][::-1]

# nessai: the total is the sampling time; "Other" is the remainder after
# training and likelihood evaluation.
base_other = (
    _as_float(baseline_results["sampling_time"]["mean"])
    - _as_float(baseline_results["likelihood_evaluation_time"]["mean"])
    - _as_float(baseline_results["training_time"]["mean"])
)
base_segments = [("Other", base_other, {"facecolor": baseline_colours[0]})] + [
    (bar_labels.get(k), baseline_results[k]["mean"], {"facecolor": baseline_colours[i + 1]})
    for i, k in enumerate(base_order)
]
base_handles, base_labels = _plot_timing_stack(
    ax,
    xloc - sep,
    baseline_results["sampling_time"]["mean"],
    baseline_results["sampling_time"]["std"],
    base_segments,
    colour=baseline_colours[0],
    marker="x",
    width=width,
)

# i-nessai: the total is sampling + redraw time; their stds are combined in
# quadrature (treated as independent).
ins_total_mean = _as_float(results["sampling_time"]["mean"]) + _as_float(
    results["redraw_time"]["mean"]
)
ins_total_std = np.sqrt(
    _as_float(results["sampling_time"]["std"]) ** 2.0
    + _as_float(results["redraw_time"]["std"]) ** 2.0
)
ins_other = (
    ins_total_mean
    - _as_float(results["update_level_time"]["mean"])
    - _as_float(results["likelihood_evaluation_time"]["mean"])
)
ins_segments = [("Other", ins_other, {"facecolor": ins_colours[0], "hatch": hatch[0]})] + [
    (
        bar_labels.get(k),
        results[k]["mean"],
        {"facecolor": ins_colours[i + 1], "hatch": hatch[i + 1], "fill": True},
    )
    for i, k in enumerate(order)
]
ins_handles, ins_labels = _plot_timing_stack(
    ax,
    xloc + sep,
    ins_total_mean,
    ins_total_std,
    ins_segments,
    colour=ins_colours[0],
    marker=".",
    width=width,
)

ax.set_ylabel("Wall time [min]")
ax.set_xlabel("Number of cores")
ax.set_xticks(xloc)
ax.set_xticklabels(n_pool)
ax.tick_params(which="minor", top=False, bottom=False)

# Invisible handle used as a column title for each sampler in the legend.
title_handle = matplotlib.patches.Rectangle(
    (0, 0), 1, 1, fill=False, edgecolor="none", visible=False
)
handles = [title_handle] + base_handles + [title_handle] + ins_handles
labels = (
    [r"\texttt{nessai}"] + base_labels + [r"\texttt{i-nessai}"] + ins_labels
)

fig.tight_layout()
ax.legend(handles=handles, labels=labels, loc="upper right", ncol=2)
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/parallelisation.pdf", bbox_inches="tight")