In [None]:
from sbi_feature_importance.experiment_helper import SimpleDB
from sbi_feature_importance.utils import extract_tags, skip_dims
from sbi_feature_importance.analysis import compare_kls
import torch

In [None]:
# Open the three result stores used by this notebook. The "r" mode flag is
# passed through to SimpleDB unchanged (presumably read-only -- confirm against
# SimpleDB). The stores are constructed in the order listed.
data_db, HH_db, HHdata_db = (
    SimpleDB(results_path, "r")
    for results_path in ("../results/fig2", "../results/fig3", "../results/HH1M")
)

In [3]:
# LaTeX table header. Raw strings keep the backslashes literal instead of
# relying on Python passing unknown escapes such as "\h" through unchanged,
# which raises an invalid-escape warning on current Python versions.
print("Model & Method & Training time & Sampling time & Total time & KL" + r" \\")
print(r"\hline")
print(r"\hline")
# Maps the method tag found in the result keys to the name printed in the table.
row_names = {"posthoc":"FSLM", "direct": "NLE", "default": "NLE", "1_1_00":"FSLM"}

# For GLM
# The "timings_edit" entry is parsed as one "<key>: <number>" record per line;
# the trailing element after the last " \n" separator is dropped.
# NOTE(review): values are divided by 60 and labelled "[min]" below, so they
# are presumably stored in seconds -- confirm against the code that writes them.
lines = data_db.query("timings_edit").split(" \n")[:-1]
dct = {line.split(":")[0]:float(line.split(": ")[1]) for line in lines}
# NOTE(review): `tags` is not read anywhere in this cell; the call is kept
# because extract_tags is defined outside this notebook.
tags = extract_tags(dct)

for i, method in enumerate(["direct", "posthoc"]):
    for j, sampling_method in enumerate(["rejection"]):
        # The mcmc branch is unreachable while only "rejection" is listed above.
        if method == "posthoc" and sampling_method == "mcmc":
            seeds = range(1,11)
        else:
            seeds = range(9)
        times = {
            key: val
            for key, val in dct.items()
            if f"{method}_{sampling_method}" in key and "fixed" in key
        }

        # Split the timings into training and sampling entries, then build one
        # row per seed by matching the "_<seed>_" fragment in the key.
        train_times = {key: t for key, t in times.items() if "train" in key}
        sample_times = {key: t for key, t in times.items() if "train" not in key}
        train_times = torch.tensor(
            [[val for key, val in train_times.items() if f"_{seed}_" in key] for seed in seeds]
        )
        sample_times = torch.tensor(
            [[val for key, val in sample_times.items() if f"_{seed}_" in key] for seed in seeds]
        )

        # Sum over the entries of each seed, then report mean and std across seeds.
        train_per_seed = train_times.sum(1)
        sample_per_seed = sample_times.sum(1)
        total_per_seed = sample_per_seed + train_per_seed
        ttrain = rf"{float(train_per_seed.mean()) / 60 :.2f} $\pm$ {float(train_per_seed.std()) / 60 :.2f}"
        tsample = rf"{float(sample_per_seed.mean()) / 60 :.2f} $\pm$ {float(sample_per_seed.std()) / 60 :.2f}"
        ttotal = rf"{float(total_per_seed.mean()) / 60 :.2f} $\pm$ {float(total_per_seed.std()) / 60 :.2f}"

        # The lookup does not depend on the subset, so do it once per method.
        found = data_db.find(f"{method}_{sampling_method}_")
        # One row per subset returned by skip_dims; each row is filled with the
        # result of compare_kls (assumed to yield 9 values per subset, matching
        # the preallocated shape -- TODO confirm).
        acc_kls = torch.zeros(4,9)
        # A dedicated index name is used here: reusing `i` would overwrite the
        # outer row index and the multirow label below would never be printed.
        for subset_idx, subset in enumerate(skip_dims([0,1,2,3])):
            selection = {
                key: val
                for key, val in found.items()
                if f"{subset}" in key
                and "fixed" in key
                and "_1_" not in key
                and "posterior" not in key
            }
            # Entries tagged "_1_" are excluded above because the direct
            # "fixed_1" entry serves as the reference for the comparison.
            base_sample = data_db.query(f"direct_{sampling_method}_fixed_1_{subset}")
            samples = list(selection.values())
            acc_kls[subset_idx] = compare_kls(samples, base_sample)
        kls = rf"{float(acc_kls.mean()):.2f} $\pm$ {float(acc_kls.std()):.2f}"

        # Only the first row of the group carries the \multirow label.
        if (i,j) == (0,0):
            multirow = r"\multirow{2}{*}{GLM [min]}"
        else:
            multirow = " "
        print(f"{multirow} & {row_names[method]} & {ttrain} & {tsample} & {ttotal} & {kls}" + r" \\")

print(r"\hline")
print(r"\hline")

# For HH
# The "timings" entry is parsed as one "<key>: <number>" record per line;
# the trailing element after the last " \n" separator is dropped.
# NOTE(review): values are divided by 3600 and labelled "[h]" below, so they
# are presumably stored in seconds -- confirm against the code that writes them.
lines = HH_db.query("timings").split(" \n")[:-1]
dct = {line.split(":")[0]:float(line.split(": ")[1]) for line in lines}
# NOTE(review): `tags` is not read anywhere in this cell; the call is kept
# because extract_tags is defined outside this notebook.
tags = extract_tags(dct)

for i, method in enumerate(["default", "1_1_00"]):
    times = {key: val for key, val in dct.items() if method in key}

    # Split the timings into training and sampling entries, then build one row
    # per seed by matching the "mcmc_<seed>_" fragment in the key.
    train_times = {key: t for key, t in times.items() if "train" in key}
    sample_times = {key: t for key, t in times.items() if "train" not in key}
    train_times = torch.tensor(
        [[val for key, val in train_times.items() if f"mcmc_{seed}_" in key] for seed in range(10)]
    )
    sample_times = torch.tensor(
        [[val for key, val in sample_times.items() if f"mcmc_{seed}_" in key] for seed in range(10)]
    )

    # Sum over the entries of each seed, then report mean and std across seeds.
    train_per_seed = train_times.sum(1)
    sample_per_seed = sample_times.sum(1)
    total_per_seed = sample_per_seed + train_per_seed
    ttrain = rf"{float(train_per_seed.mean()) / 3600 :.2f} $\pm$ {float(train_per_seed.std()) / 3600 :.2f}"
    tsample = rf"{float(sample_per_seed.mean()) / 3600 :.2f} $\pm$ {float(sample_per_seed.std()) / 3600 :.2f}"
    ttotal = rf"{float(total_per_seed.mean()) / 3600 :.2f} $\pm$ {float(total_per_seed.std()) / 3600 :.2f}"

    items = HH_db.find(method)
    samples = {key: val for key, val in items.items() if "posterior" not in key}
    # One row per subset returned by skip_dims; each row is filled with the
    # result of compare_kls (assumed to yield 9 values per subset, matching the
    # preallocated shape -- TODO confirm).
    agg_kls = torch.zeros(10,9)
    # A dedicated index name is used here: reusing `i` would overwrite the
    # outer row index and the multirow label below would never be printed.
    for subset_idx, subset in enumerate(skip_dims([0, 1, 2, 3, 8, 13, 18, 19, 21, 22])):
        # Entries tagged "mcmc_0_" are excluded because the direct/default
        # "mcmc_0" entry serves as the reference for the comparison.
        subset_samples = [
            val
            for key, val in samples.items()
            if str(subset) in key and "mcmc_0_" not in key
        ]
        base_sample = HH_db.query(f"direct_default_mcmc_0_{subset}")
        agg_kls[subset_idx] = compare_kls(subset_samples, base_sample)

    kls = rf"{float(agg_kls.mean()):.2f} $\pm$ {float(agg_kls.std()):.2f}"

    # Only the first row of the group carries the \multirow label.
    if i == 0:
        multirow = r"\multirow{2}{*}{HH [h]}"
    else:
        multirow = " "
    print(f"{multirow} & {row_names[method]} & {ttrain} & {tsample} & {ttotal} & {kls}" + r" \\")

Model & Method & Training time & Sampling time & Total time & KL \\
\hline
\hline
  & NLE & 3.57 $\pm$ 0.79 & 0.33 $\pm$ 0.04 & 3.89 $\pm$ 0.80 & 0.06 $\pm$ 0.10 \\
  & FSLM & 0.64 $\pm$ 0.21 & 0.66 $\pm$ 0.10 & 1.30 $\pm$ 0.26 & 0.04 $\pm$ 0.09 \\
\hline
\hline
  & NLE & 84.20 $\pm$ 9.28 & 10.34 $\pm$ 0.54 & 94.53 $\pm$ 9.24 & 1.90 $\pm$ 3.99 \\
  & FSLM & 7.21 $\pm$ 1.49 & 13.27 $\pm$ 1.04 & 20.48 $\pm$ 2.08 & 1.53 $\pm$ 1.32 \\
