In [None]:
from efficient_probit_regression.datasets import BaseDataset, Covertype, KDDCup, Webspam
from efficient_probit_regression import settings

import matplotlib.pyplot as plt

import seaborn as sns

import numpy as np
import pandas as pd

In [None]:
def make_plot(dataset: BaseDataset, methods, beta_list, legend=True, *, run=1, size=15000):
    """Draw and save box plots comparing coefficient samples across methods.

    The samples stored for the full data set are compared with the samples
    obtained by each of the given reduction ``methods``.  One group of boxes
    is drawn per coefficient index in ``beta_list``, with one box per method.

    Args:
        dataset: Data set whose ``get_name()`` is used to locate the result
            CSV files below ``settings.RESULTS_DIR_BAYES`` and to build the
            plot title and the output file name.
        methods: Iterable of method keys; each must be one of ``"original"``,
            ``"uniform"``, ``"leverage"`` or ``"leverage_online"``.
        beta_list: Non-empty iterable of integer coefficient indices to plot.
        legend: If false, the hue legend is removed from the axes.
        run: Number of the run whose result file is read for each method.
        size: Only rows whose ``size`` column equals this value are used.

    Raises:
        ValueError: If ``beta_list`` is empty.
        KeyError: If ``methods`` contains an unknown method key.

    Side effects:
        Sets ``text.usetex`` and the font size globally in matplotlib, writes
        ``<name>_coefficients_<min(beta_list)>.pdf`` to ``settings.PLOTS_DIR``
        and shows the figure.
    """
    # Materialize once: the indices are needed both for filtering and for the
    # output file name, and a one-shot iterable would be exhausted in between.
    betas = list(beta_list)
    if not betas:
        # Fail before any file is read instead of at ``min()`` further down.
        raise ValueError("beta_list must contain at least one coefficient index")

    dataset_name = dataset.get_name()

    method_names = {
        "original": "original",
        "uniform": "uniform",
        "leverage": "two pass",
        "leverage_online": "online",
    }

    # Reference samples of the full data set, reshaped to long format
    # (columns ``variable`` and ``value``).
    original_df = (
        pd.read_csv(settings.RESULTS_DIR_BAYES / f"{dataset_name}_sample_full.csv")
        .melt()
        .assign(method="original")
    )

    df_list = [original_df]
    for method in methods:
        # Look the label up first so that an unknown method fails fast.
        label = method_names[method]
        cur_df = (
            pd.read_csv(settings.RESULTS_DIR_BAYES / f"{dataset_name}_sample_{method}_run_{run}.csv")
            .query("size == @size")
            # Drop the bookkeeping columns so only coefficient columns are melted.
            .drop(columns=["run", "size", "reduction_time_s", "total_time_s"])
            .melt()
            .assign(method=label)
        )
        df_list.append(cur_df)

    df = pd.concat(df_list, ignore_index=True)
    df = (
        # The coefficient index is the integer after the first underscore of
        # the original column name.
        df.assign(index=lambda x: x.variable.str.split("_", expand=True)[1].astype(int))
        .drop(columns=["variable"])
        .query("index in @betas")
    )

    plt.rcParams["text.usetex"] = True
    plt.rc("font", size=15)

    tab20 = sns.color_palette("tab20")
    base_colors = [tab20[4], tab20[0], tab20[6], tab20[8]]
    # Hand seaborn exactly as many colours as there are hue levels; a longer
    # list triggers a "palette list has more values than needed" warning.
    # Seaborn uses the leading colours either way, so the colours are unchanged.
    n_levels = df["method"].nunique()
    palette = sns.color_palette(base_colors[:n_levels])

    fig, ax = plt.subplots()
    sns.boxplot(data=df, x="index", y="value", hue="method", ax=ax, palette=palette)

    ax.set_title(r"$\beta$-" f"{dataset_name.capitalize()}, size = {size}", fontsize=23)

    if not legend:
        current_legend = ax.get_legend()
        # ``get_legend`` returns None when no legend was created.
        if current_legend is not None:
            current_legend.remove()

    fig.tight_layout()

    fig.savefig(settings.PLOTS_DIR / f"{dataset_name}_coefficients_{min(betas)}.pdf")

    fig.show()


# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[0])
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[1, 2, 3, 4, 5], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[6, 7, 8, 9], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[10, 11, 12, 13, 14], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[15, 16, 17, 18, 19], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[20, 21, 22, 23, 24], legend=False)


# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[25, 26, 27, 28, 29], legend=True)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[30, 31, 32, 33, 34], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[35, 36, 37, 38, 39], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[40, 41, 42, 43, 44], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[45, 46, 47, 48, 49], legend=False)
# make_plot(Covertype(), methods=["uniform", "leverage", "leverage_online"], beta_list=[50, 51, 52, 53, 54], legend=False)