In [None]:
import os
import numpy as np
import json
import matplotlib.pyplot as plt

def open_violators(max_tree, leaf_cutoff, sample_rate,
                   baseline_dir="/localdata/sorel/covertrees/test_set_baselines"):
    """Load and concatenate the leave-one-out violator records of several trees.

    For every ``tree_idx`` in ``range(max_tree)`` this reads
    ``tree_{leaf_cutoff}_{tree_idx}_baseline_{sample_rate}_loo_violators.json``
    from ``baseline_dir``. Each file must hold a JSON list of 3-element records
    ``[<ignored>, observation, baseline]``:

    * ``observation`` provides ``kl_div`` and ``mll``;
    * ``baseline`` provides ``mean_kl_div``, ``mean_mll``, ``std_kl_div``,
      ``std_mll``, ``max_kl_div`` and ``min_mll``.

    Args:
        max_tree: Number of per-tree files to read (indices ``0 .. max_tree - 1``).
        leaf_cutoff: Leaf cutoff embedded in the file names.
        sample_rate: Baseline sample rate embedded in the file names.
        baseline_dir: Directory holding the JSON files. Defaults to the path
            that was previously hard-coded.

    Returns:
        dict mapping each of the eight field names above (observation fields
        first, then baseline fields) to an ``np.ndarray`` with one entry per
        record, concatenated across trees in file order.

    Raises:
        FileNotFoundError: if an expected per-tree file does not exist.
        KeyError: if a record lacks one of the expected fields.
    """
    observation_keys = ("kl_div", "mll")
    baseline_keys = ("mean_kl_div", "mean_mll", "std_kl_div", "std_mll", "max_kl_div", "min_mll")
    violators = {key: [] for key in observation_keys + baseline_keys}

    for tree_idx in range(max_tree):
        baseline_file = f"tree_{leaf_cutoff}_{tree_idx}_baseline_{sample_rate}_loo_violators.json"
        # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
        with open(os.path.join(baseline_dir, baseline_file), encoding="utf-8") as f:
            loo_violators = json.load(f)
        # The first element of each record is not needed here.
        for _, observation, baseline in loo_violators:
            for key in observation_keys:
                violators[key].append(observation[key])
            for key in baseline_keys:
                violators[key].append(baseline[key])

    return {key: np.array(values) for key, values in violators.items()}

In [None]:
# Build one violator table per baseline sample rate (the dict key); each table
# is read from 48 per-tree files generated with a leaf cutoff of 500.
violators = {
    sample_rate: open_violators(48, 500, sample_rate)
    for sample_rate in (1000, 10000, 100000)
}

In [None]:
def _above_criterion(x, y, cor, intercept):
    """Return a boolean mask (same shape as ``x``/``y``) of points strictly above ``y = cor*x + intercept``."""
    # Must stay a plain boolean ndarray: wrapping it in a list (``[mask]``) makes
    # NumPy >= 1.23 interpret it as a 2-D index, which raises IndexError on 1-D data.
    return np.asarray(y - cor*x - intercept > 0)


def _draw_criterion_panel(ax, x, y, cor, intercept, line_label, title, ylabel, xlabel, **scatter_kwargs):
    """Scatter ``(x, y)`` on ``ax``, overlay the line ``cor*x + intercept`` in orange, then title/label/legend the panel."""
    ax.scatter(x, y, **scatter_kwargs)
    ax.plot(x, cor*x + intercept, label=line_label, color="orange")
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.legend()


def eval_baseline_hyperparameters(all_violators, sequence_len, kl_str, kl_cor, mll_str, mll_cor, safety_margin=1.5):
    """Plot and count the violators flagged by two linear criteria for one sequence length.

    Using the arrays in ``all_violators[sequence_len]`` (a dict of field name ->
    ndarray, e.g. as returned by ``open_violators``) two criteria are evaluated:

    * KL div: ``safety_margin*kl_div - max_kl_div > kl_cor*std_kl_div + kl_str``
    * MLL:    ``min_mll - safety_margin*mll > mll_cor*std_mll + mll_str``

    Points satisfying a criterion are reported as false positives. A 2x2 figure
    is drawn: the top row shows every point with its criterion line, the bottom
    row only the flagged points with their share of all points in the legend.
    The figure is saved to ``f"{sequence_len}_violators.png"`` in the current
    working directory (1000 dpi) and then shown.

    Args:
        all_violators: Mapping from sequence length to a violator table with
            keys ``kl_div``, ``mll``, ``std_kl_div``, ``std_mll``,
            ``max_kl_div`` and ``min_mll``.
        sequence_len: Key of ``all_violators`` to evaluate; also used in the
            output file name.
        kl_str: Intercept of the KL div criterion line.
        kl_cor: Slope of the KL div criterion line.
        mll_str: Intercept of the MLL criterion line.
        mll_cor: Slope of the MLL criterion line.
        safety_margin: Factor applied to the node's own KL div / MLL.

    Returns:
        None. Output is the figure (file + display).
    """
    seq_violators = all_violators[sequence_len]

    x_kl_div = seq_violators["std_kl_div"]
    y_kl_div = safety_margin*seq_violators["kl_div"] - seq_violators["max_kl_div"]
    total_samples = len(y_kl_div)
    kl_selection = _above_criterion(x_kl_div, y_kl_div, kl_cor, kl_str)
    kl_div_fp_count = int(np.count_nonzero(kl_selection))

    x_mll = seq_violators["std_mll"]
    y_mll = seq_violators["min_mll"] - safety_margin*seq_violators["mll"]
    mll_selection = _above_criterion(x_mll, y_mll, mll_cor, mll_str)
    mll_fp_count = int(np.count_nonzero(mll_selection))

    # Guard the percentage so an empty violator table does not divide by zero.
    kl_fp_pct = 100*kl_div_fp_count/total_samples if total_samples else 0.0
    mll_fp_pct = 100*mll_fp_count/total_samples if total_samples else 0.0

    kl_line_label = f"KL Div Criterion {kl_cor} x + {kl_str}"
    kl_ylabel = "Node's KL Div - maximum KL Div from baseline"
    kl_xlabel = "Std of Node's KL Div"
    mll_line_label = f"MLL Criterion {mll_cor} x + {mll_str}"
    mll_ylabel = "Minimum MLL from Baseline - Node's MLL"
    mll_xlabel = "Std of Node's MLL"

    fig, axs = plt.subplots(2, 2)
    # Bottom row: only the flagged points; top row: every point.
    _draw_criterion_panel(
        axs[1, 0], x_kl_div[kl_selection], y_kl_div[kl_selection], kl_cor, kl_str,
        kl_line_label, "KL Div False Positives", kl_ylabel, kl_xlabel,
        label=f"False Positives {kl_fp_pct}%", color="red",
    )
    _draw_criterion_panel(
        axs[0, 0], x_kl_div, y_kl_div, kl_cor, kl_str,
        kl_line_label, "All Potential KL Div Violators", kl_ylabel, kl_xlabel,
    )
    _draw_criterion_panel(
        axs[1, 1], x_mll[mll_selection], y_mll[mll_selection], mll_cor, mll_str,
        mll_line_label, "MLL False Positives", mll_ylabel, mll_xlabel,
        label=f"False Positives {mll_fp_pct}%", color="red",
    )
    _draw_criterion_panel(
        axs[0, 1], x_mll, y_mll, mll_cor, mll_str,
        mll_line_label, "All Potential MLL Violators", mll_ylabel, mll_xlabel,
    )

    fig.set_size_inches(10, 8)
    fig.tight_layout()
    fig.savefig(f"{sequence_len}_violators.png", bbox_inches="tight", dpi=1000)
    plt.show()

In [None]:
# Criterion parameters for the 1000-sample-rate baseline.
eval_baseline_hyperparameters(
    all_violators=violators,
    sequence_len=1000,
    kl_str=12,
    kl_cor=10,
    mll_str=80,
    mll_cor=1.3,
    safety_margin=2,
)


In [None]:
# Criterion parameters for the 10000-sample-rate baseline.
eval_baseline_hyperparameters(
    all_violators=violators,
    sequence_len=10000,
    kl_str=6.5,
    kl_cor=20,
    mll_str=100,
    mll_cor=1.4,
    safety_margin=2,
)

In [None]:
# Criterion parameters for the 100000-sample-rate baseline.
eval_baseline_hyperparameters(
    all_violators=violators,
    sequence_len=100000,
    kl_str=80,
    kl_cor=15,
    mll_str=100,
    mll_cor=1.9,
    safety_margin=2,
)