In [None]:
from efficient_probit_regression import settings
from efficient_probit_regression.datasets import BaseDataset, Covertype, KDDCup, Webspam

from scipy.stats import median_abs_deviation

import matplotlib.pyplot as plt
import matplotlib

import pandas as pd
import numpy as np

import seaborn as sns

In [None]:
# Make sure the output directory for figures exists. parents=True also creates
# missing intermediate directories; exist_ok=True makes re-running this cell a
# no-op instead of an error (and avoids a check-then-create race).
# NOTE(review): assumes settings.PLOTS_DIR is a pathlib.Path — confirm.
settings.PLOTS_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
def get_results_df(dataset: BaseDataset, metric, methods, q_lower=0.25, q_upper=0.75):
    """Load per-method result CSVs and summarise ``metric`` per sample size.

    For every method, reads
    ``settings.RESULTS_DIR_BAYES / f"{dataset.get_name()}_{metric}_{method}.csv"``,
    groups the rows by the ``size`` column and aggregates the ``metric`` column.

    Args:
        dataset: Dataset whose ``get_name()`` prefixes the result file names.
        metric: Metric name; used in the file name and as the column to aggregate.
        methods: Iterable of method names; one CSV file is read per method.
        q_lower: Quantile reported in the ``q_lower`` column (default 0.25).
        q_upper: Quantile reported in the ``q_upper`` column (default 0.75).

    Returns:
        DataFrame with columns ``size``, ``median``, ``q_upper``, ``q_lower``,
        ``count`` and ``method`` (one row per method and size). If ``methods``
        is empty, an empty DataFrame with these columns is returned instead of
        letting ``pd.concat`` raise on an empty list.
    """
    def lower_quantile(values):
        return np.quantile(values, q=q_lower)

    def upper_quantile(values):
        return np.quantile(values, q=q_upper)

    df_list = []
    for method in methods:
        csv_path = settings.RESULTS_DIR_BAYES / f"{dataset.get_name()}_{metric}_{method}.csv"
        df = (
            pd.read_csv(csv_path)
            .groupby(["size"], as_index=False)
            .agg(
                median=pd.NamedAgg(column=metric, aggfunc="median"),
                q_upper=pd.NamedAgg(column=metric, aggfunc=upper_quantile),
                q_lower=pd.NamedAgg(column=metric, aggfunc=lower_quantile),
                count=pd.NamedAgg(column=metric, aggfunc="count"),
            )
            .assign(method=method)
        )
        df_list.append(df)

    if not df_list:
        # pd.concat([]) raises ValueError; keep the schema callers index into.
        return pd.DataFrame(columns=["size", "median", "q_upper", "q_lower", "count", "method"])

    return pd.concat(df_list, ignore_index=True)

# Quick look at the aggregated MMD results of the two-pass leverage method on Covertype.
get_results_df(dataset=Covertype(), metric="mmd", methods=["leverage"])

In [None]:
def make_plot(dataset, metric, methods, x_min=None, x_max=None, y_min=None, y_max=None, plot_bands=True, font_size=15, font_size_title=23, save=False):
    """Plot the per-size median of ``metric`` for each method on ``dataset``.

    Args:
        dataset: Dataset object; ``dataset.get_name()`` selects the result files
            (via ``get_results_df``) and the plot title.
        metric: Metric name, e.g. ``"mmd"``, ``"norm"`` or ``"matrix_norm"``.
        methods: Method names to draw, in legend order.
        x_min, x_max, y_min, y_max: Axis limits; ``None`` keeps matplotlib's
            automatic limit on that side.
        plot_bands: If True, draw error bars from ``q_lower`` to ``q_upper``
            around each median.
        font_size: Global font size applied via ``plt.rc``.
        font_size_title: Font size of the title.
        save: If True, also write the figure to
            ``settings.PLOTS_DIR / f"{dataset_name}_bayes_plot_{metric}.pdf"``.

    Note:
        Changes global ``plt.rcParams`` (TeX rendering, LaTeX preamble, font
        size); these settings persist after the call.
    """
    results_df = get_results_df(dataset, metric, methods)

    # Use TeX for typesetting. The preamble must be a single string: lists were
    # deprecated in matplotlib 3.3 and are rejected by newer versions.
    plt.rcParams["text.usetex"] = True
    plt.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # for \lVert / \rVert
    plt.rc("font", size=font_size)

    fig, ax = plt.subplots()

    # plt.get_cmap works across matplotlib versions, whereas
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
    colormap = plt.get_cmap("tab20")
    colors = {
        "uniform": colormap(0),
        "leverage": colormap(6),
        "leverage_online": colormap(8),
    }

    labels = {
        "uniform": "uniform",
        "leverage": "two pass",
        "leverage_online": "online",
    }

    y_labels = {
        "mmd": "MMD",
        "norm": r'$\lVert \mu_\beta - \widetilde{\mu}_\beta \rVert_2$',
        "matrix_norm": r'$\mathbf{\lVert \Sigma_\beta - \widetilde{\Sigma}_\beta \rVert_2}$',
    }

    titles = {
        "covertype": "Covertype",
        "kddcup": "Kddcup",
        "webspam": "Webspam",
    }

    dataset_name = dataset.get_name()

    for cur_method in methods:
        cur_results = results_df.loc[results_df["method"] == cur_method]
        # Unknown methods fall back to matplotlib's colour cycle and their raw name.
        (line,) = ax.plot(
            cur_results["size"],
            cur_results["median"],
            color=colors.get(cur_method),
            label=labels.get(cur_method, cur_method),
        )
        if plot_bands:
            ax.errorbar(
                x=cur_results["size"],
                y=cur_results["median"],
                yerr=(
                    cur_results["median"] - cur_results["q_lower"],
                    cur_results["q_upper"] - cur_results["median"],
                ),
                fmt="none",  # bars only: the median line is already drawn above
                color=line.get_color(),
                alpha=0.3,
            )

    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    ax.set_xlabel("reduced size")
    ax.set_ylabel(y_labels.get(metric, metric))

    ax.set_title(titles.get(dataset_name, dataset_name), fontsize=font_size_title)

    ax.legend(loc="upper right", frameon=True)

    fig.tight_layout()

    if save:
        fig.savefig(settings.PLOTS_DIR / f"{dataset_name}_bayes_plot_{metric}.pdf")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Every figure compares the same three methods and starts the y-axis at 0;
# only the upper y-limit depends on the (metric, dataset) pair. Dict order
# fixes the plotting order: metrics outer, datasets inner.
plot_methods = ["uniform", "leverage", "leverage_online"]
y_max_by_metric = {
    "mmd": {Covertype: 1000, KDDCup: 1000, Webspam: 1000},
    "norm": {Covertype: 50, KDDCup: 50, Webspam: 50},
    "matrix_norm": {Covertype: 100, KDDCup: 20, Webspam: 100},
}

for metric_name, y_max_by_dataset in y_max_by_metric.items():
    for dataset_cls, upper in y_max_by_dataset.items():
        make_plot(dataset=dataset_cls(), metric=metric_name, methods=plot_methods, y_min=0, y_max=upper)