In [None]:
%env JAX_PLATFORM_NAME=cpu
import os
import warnings

from tqdm.auto import tqdm
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import odds_datasets
from balif import Balif
import pandas as pd
import seaborn as sns
warnings.filterwarnings("ignore")
import plotly.io as pio
pio.templates.default = "seaborn"
import plotly.express as px

In [None]:
def _balif_results_path(query_folder):
    """Return a function mapping a dataset name to its saved BALIF AP results file."""
    return lambda ds: f"results/{query_folder}/{ds}_BALIF.npy"


# Query-strategy label -> function(dataset name) -> path of the saved AP array.
# The EBALIF variant ("{ds}_EBALIF.npy" in the same folders) is currently left out.
paths = {
    "Margin": _balif_results_path("margin_query"),
    "Random": _balif_results_path("random_query"),
}

In [None]:
# Preload every saved AP array, keyed by strategy label then dataset name.
# Each array is transposed so that axis 0 indexes query steps (presumably the
# files are stored as (runs, queries) — TODO confirm against the writer).
# NOTE(review): allow_pickle=True can execute code from a crafted .npy file;
# it is kept for compatibility with how the results were saved — only load
# trusted result files.
ap = {
    model_name: {
        dataset_name: np.load(save_path_fn(dataset_name), allow_pickle=True).T
        for dataset_name in tqdm(odds_datasets.datasets_names, desc=model_name)
    }
    for model_name, save_path_fn in paths.items()
}

In [None]:
def _dataset_summary(X, y):
    """Summarize one dataset.

    Returns a dict with ``contamination`` (mean of ``y`` — the outlier fraction,
    assuming ``y`` is a 0/1 outlier label vector), ``points`` (rows of ``X``)
    and ``features`` (columns of ``X``).
    """
    return {"contamination": y.mean(), "points": X.shape[0], "features": X.shape[1]}


# Load each dataset once and keep only its summary, so the raw arrays of all
# datasets are never held in memory together and `datasets` always means
# "name -> summary".
datasets = {
    name: _dataset_summary(*odds_datasets.load(name))
    for name in odds_datasets.datasets_names
}

# (name, display label) pairs ordered by increasing contamination (ties broken
# by name); the display label is currently just the name.
contamination = [(summary["contamination"], name) for name, summary in datasets.items()]
dataset_names = [(name, f"{name}") for _, name in sorted(contamination)]
dataset_names

In [None]:
def _ap_checkpoints(ap_by_query):
    """Pick the AP snapshots used as animation frames.

    ``ap_by_query`` is indexed by query step on axis 0 (row k = AP of every run
    after k queries — assumed from the transpose applied on load). Returns
    ``(label, values)`` pairs:

    - ``"0%"`` for the first row;
    - ``f"{pct}%"`` for pct = 1..9 at row ``1 + int(pct * len / 10)``, clamped
      to the last row so very short curves do not raise IndexError;
    - ``"10%"`` for the last row.

    The percentage labels presume the full curve spans a labelling budget of
    10% of the dataset — TODO confirm.
    """
    n_steps = len(ap_by_query)
    checkpoints = [("0%", ap_by_query[0])]
    for pct in range(1, 10):
        row = min(1 + int(pct * n_steps / 10), n_steps - 1)
        checkpoints.append((f"{pct}%", ap_by_query[row]))
    checkpoints.append(("10%", ap_by_query[-1]))
    return checkpoints


# One record per (dataset, strategy, checkpoint): mean/std of AP across runs.
# Records are collected in a list and turned into a frame once, instead of
# growing a DataFrame row by row (quadratic; and `_append` is private pandas API).
records = []
for dataset_name, label in tqdm(dataset_names):
    for model_name, save_path_fn in paths.items():
        ap_by_query = np.load(save_path_fn(dataset_name), allow_pickle=True).T
        for queries, ap_q in _ap_checkpoints(ap_by_query):
            records.append(
                {
                    "dataset": label,
                    "strategy": model_name,
                    "ap_mean": ap_q.mean(),
                    "ap_std": ap_q.std(),
                    "labels": queries,
                }
            )
df = pd.DataFrame(records)

In [None]:
# Radar chart of mean AP per dataset (angular axis, in `df` row order), one
# trace per query strategy, animated over the "labels" checkpoints.
fig = px.line_polar(
    df,
    r="ap_mean",
    theta="dataset",
    color="strategy",
    animation_frame="labels",
    range_r=[0, 1.01],
    line_close=True,
    markers=True,
    title="AP evolution",
    width=800,
    height=800,
)
# Optional: remove the animation buttons stored in layout.updatemenus.
fig.layout.pop("updatemenus")
fig.write_html("file.html")
fig.show()

In [None]:
# Grid of AP-vs-number-of-queries curves: one panel per ODDS dataset
# (mean ± 1 std across runs, band clipped to [0, 1]) plus a final legend-only
# panel. Result files come from `paths`; dataset sizes from `datasets`.
N_COLS = 3
FIG_PATH = "figures/ap_evolution/ap_all.pdf"

dataset_order = list(odds_datasets.datasets_names)
n_panels = len(dataset_order) + 1  # one extra panel holds the shared legend
n_rows = -(-n_panels // N_COLS)  # ceiling division: grid grows with the dataset count
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(20, 5 * n_rows), squeeze=False)
axes = axes.ravel()
# Fixed colour per strategy so curve, band and legend entry always agree.
strategy_colors = {name: f"C{k}" for k, name in enumerate(paths)}

for i, dataset_name in enumerate(tqdm(dataset_order)):
    ax = axes[i]
    n_points = datasets[dataset_name]["points"]
    for model_name, save_path_fn in paths.items():
        # Averaged over axis 0 — presumably one row per run, one column per
        # query step (TODO confirm against the writer).
        ap_runs = np.load(save_path_fn(dataset_name), allow_pickle=True)
        ap_mean, ap_std = ap_runs.mean(axis=0), ap_runs.std(axis=0)
        # Shift by one so the "0 queries" point can sit on a log x-axis.
        n_queries = 1 + np.arange(len(ap_mean))
        color = strategy_colors[model_name]
        ax.plot(n_queries, ap_mean, color=color)
        ax.fill_between(
            n_queries,
            np.maximum(0, ap_mean - ap_std),
            np.minimum(1, ap_mean + ap_std),
            color=color,
            alpha=0.3,
        )
    ax.set_xscale("log")
    ax.set_title(dataset_name)
    ax.set_xlabel(f"Queries (10% ={1+n_points//10})")
    if i % N_COLS == 0:
        ax.set_ylabel("Avg. Precision")
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlim(1, 1 + n_points / 10)
    ax.set_xticks([1, 1 + n_points / 100, 1 + n_points / 10])
    ax.set_xticklabels(["0%", "1%", "10%"])
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax.grid()

# Legend-only panel right after the last dataset; empty proxy lines carry the labels.
legend_ax = axes[len(dataset_order)]
for model_name, color in strategy_colors.items():
    legend_ax.plot([], [], color=color, label=model_name)
legend_ax.set_title("")
legend_ax.legend(loc="center")
legend_ax.set_axis_off()
# Hide any leftover grid slots.
for unused_ax in axes[n_panels:]:
    unused_ax.set_axis_off()

fig.tight_layout()
os.makedirs(os.path.dirname(FIG_PATH), exist_ok=True)
fig.savefig(FIG_PATH, bbox_inches="tight")
plt.show()