In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def load_many(path, path2):
    """Load flattened model outputs and the matching ground-truth values.

    Args:
        path: A path, or a list/tuple of paths, to ``.npz`` prediction
            archives. Each archive must contain the arrays
            ``"valid_output"`` and ``"test_output"``.
        path2: A path, or a list/tuple of paths, to ``.npy`` files holding a
            pickled dict with a ``"y"`` entry (the true values).

    Returns:
        A tuple ``(true_out, valid_out, test_out)`` where

        * ``true_out`` is the concatenation of every ``"y"`` entry,
        * ``valid_out`` is the flattened ``"valid_output"`` of the *last*
          prediction archive only (earlier archives are not accumulated),
        * ``test_out`` is the concatenation of every flattened
          ``"test_output"`` with a small Gaussian jitter added.

    Raises:
        ValueError: If either argument is an empty list/tuple.
    """
    pred_paths = list(path) if isinstance(path, (list, tuple)) else [path]
    true_paths = list(path2) if isinstance(path2, (list, tuple)) else [path2]
    if not pred_paths or not true_paths:
        raise ValueError(
            "load_many needs at least one prediction path and one ground-truth path"
        )

    valid_out = None
    test_outs = []
    for p in pred_paths:
        # Use the archive as a context manager so its file handle is released
        # as soon as the arrays have been copied out by ``flatten()``.
        with np.load(p) as archive:
            valid_out = archive["valid_output"].flatten()
            test_outs.append(archive["test_output"].flatten())
    test_out = np.concatenate(test_outs)

    # NOTE(security): allow_pickle=True unpickles the file contents, which can
    # execute arbitrary code. Only point this at files from a trusted source.
    true_outs = [np.load(p, allow_pickle=True).item()["y"] for p in true_paths]
    true_out = np.concatenate(true_outs)

    # The in-place add below cannot write floats into an integer/bool array,
    # so promote non-floating predictions first. Floating inputs keep their
    # original dtype.
    if not np.issubdtype(test_out.dtype, np.inexact):
        test_out = test_out.astype(np.float64)
    # Tiny random jitter -- presumably to break ties between identical
    # predictions before quantiles are taken downstream (TODO confirm).
    test_out += 0.00001 * np.random.randn(*test_out.shape)
    return true_out, valid_out, test_out

def get_error(path_to_preds, path_to_true, method, minute, quantile=0.9, n_bins=500):
    """Plot the screened rate against the percentile of the true value.

    Predictions at or above the ``quantile`` cutoff of all predictions count
    as "screened". Items are ordered by their true value, split into
    ``n_bins`` equally populated buckets, and the fraction screened in each
    bucket is plotted. The figure is saved to
    ``{method}_acceptance_rate_minute{minute:02d}.png`` and then shown.

    Args:
        path_to_preds: Prediction archive path(s), forwarded to ``load_many``.
        path_to_true: Ground-truth file path(s), forwarded to ``load_many``.
        method: Label used as the prefix of the saved file name.
        minute: Integer used (zero-padded to two digits) in the file name.
        quantile: Prediction quantile in ``[0, 1]`` used as the screening
            cutoff; also sets the position of the vertical threshold line.
        n_bins: Number of rank buckets the items are grouped into.

    Returns:
        None. The function's effect is the saved and displayed figure.
    """
    true_out, _, test_out = load_many(path_to_preds, path_to_true)

    # Shuffle the rows before sorting -- presumably so that items with equal
    # true values end up in arbitrary order for rank(method="first") below
    # (TODO confirm intent).
    df = pd.DataFrame({"true": true_out, "preds": test_out}).sample(frac=1)
    cutoff = df["preds"].quantile(quantile)
    df = df.sort_values(by=["true"])
    df["accepted"] = (df["preds"] < cutoff).astype(int)
    # Map ranks 1..len(df) onto integer buckets 0..n_bins-1; the 0.99999
    # factor keeps the top rank from landing in bucket n_bins.
    df["true_rank"] = (
        df["true"].rank(method="first") * n_bins / len(df) * 0.99999
    ).astype(int)
    per_bucket = df.groupby(["true_rank"]).mean()
    percentiles = np.linspace(0, 100, len(per_bucket))

    # Draw on a dedicated figure so repeated calls never stack curves on a
    # shared implicit figure.
    fig, ax = plt.subplots()
    ax.plot(percentiles, 1 - per_bucket.accepted)
    threshold = 100 * quantile
    ax.plot([threshold, threshold], [0, 1.0], label="Heavy Hitter Threshold")
    ax.set_xlabel("Frequency Percentile")
    ax.set_ylabel("Screening Probability")
    ax.set_title("Screened Rate")
    # The legend must exist before saving, otherwise the PNG lacks it.
    ax.legend()
    fig.savefig(f"{method}_acceptance_rate_minute{minute:02d}.png")
    plt.show()
    plt.close(fig)

# Values substituted into the prediction and ground-truth file names below.
EVAL_MINUTES = [8, 29, 59]
# Run directories under tb_logs_modded/ whose predictions are plotted.
EVAL_METHODS = [
    "bn-8-HsuRNN-True-ckpts-forwards-more",
    "bn-64-HsuRNN-True-ckpts-forwards-more",
    "l1-HsuRNN-False-ckpts-forwards-more",
    "l1-HsuRNN-True-ckpts-forwards-more",
    "log_mse-HsuRNN-False-ckpts-forwards-more",
    "log_mse-HsuRNN-True-ckpts-forwards-more",
]
EVAL_TRIALS = [1]

# Produce one screened-rate plot per (minute, method, trial) combination.
for minute in EVAL_MINUTES:
    for method in EVAL_METHODS:
        for trial in EVAL_TRIALS:
            path_to_true = f"equinix-chicago.dirA.20160121-13{minute:02d}00.ports.npy"
            path_to_preds = f"tb_logs_modded/{method}/trial{trial}/lightning_logs/predictions{minute:02d}_res.npz"
            get_error(path_to_preds, path_to_true, method, minute)
    