In [None]:
from efficient_probit_regression import settings
from efficient_probit_regression.datasets import BaseDataset, Covertype, KDDCup, Webspam

from scipy.stats import median_abs_deviation

import matplotlib.pyplot as plt
import matplotlib

import pandas as pd
import numpy as np

import seaborn as sns

In [None]:
# Make sure the output directory for the plots exists before anything is saved.
# exist_ok=True avoids the check-then-create race of a separate exists() test,
# and parents=True also creates missing parent directories.
# NOTE(review): assumes settings.PLOTS_DIR is a pathlib.Path — confirm.
settings.PLOTS_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
def get_results_df(dataset: BaseDataset, p, methods):
    """Collect per-size summary statistics of the ``ratio`` column for each method.

    For every entry of ``methods`` the CSV file
    ``<dataset name>_<method>_p_<p>.csv`` inside ``settings.get_results_dir_p(p)``
    is read, reduced to its ``ratio`` and ``size`` columns and grouped by ``size``.

    Args:
        dataset: Object whose ``get_name()`` supplies the file name prefix.
        p: Value used both to locate the results directory and in the file name.
        methods: Iterable of method names; one CSV file is read per entry.

    Returns:
        One DataFrame with a fresh integer index and the columns ``size``,
        ``median_ratio``, ``std``, ``mad``, ``q75``, ``q25`` and ``method``,
        holding one row per (method, size) combination.
    """

    def summarize(method):
        # Build the path of the result file that belongs to this method.
        csv_path = settings.get_results_dir_p(p) / (dataset.get_name() + f"_{method}_p_{p}.csv")

        ratios_by_size = (
            pd.read_csv(csv_path)
            .filter(items=["ratio", "size"])
            .groupby(["size"], as_index=False)
        )

        # Named aggregation: every keyword becomes one output column.
        summary = ratios_by_size.agg(
            median_ratio=("ratio", "median"),
            std=("ratio", "std"),
            mad=("ratio", median_abs_deviation),
            q75=("ratio", lambda values: np.quantile(values, q=0.75)),
            q25=("ratio", lambda values: np.quantile(values, q=0.25)),
        )

        # Tag the rows so the methods stay distinguishable after concatenation.
        return summary.assign(method=method)

    return pd.concat([summarize(method) for method in methods], ignore_index=True)

In [None]:
def make_plot(dataset, p, methods, x_min, x_max, y_min, y_max, plot_bands=True, font_size=15, font_size_title=23):
    """Plot the median approximation ratio against the reduced size, one line per method.

    The summary statistics come from ``get_results_df(dataset, p, methods)``.
    The figure is written to
    ``settings.PLOTS_DIR / f"{dataset.get_name()}_ratio_plot_p_{p}.pdf"`` and
    then shown.

    Note that this function changes global matplotlib settings: it switches on
    TeX typesetting (``text.usetex``) and sets the default font size.

    Args:
        dataset: Object whose ``get_name()`` selects the result files, the plot
            title and the output file name. The name must be one of
            ``"covertype"``, ``"kddcup"`` or ``"webspam"``.
        p: Value passed on to ``get_results_df`` and used in labels, title and
            file name.
        methods: Method names to draw. Each must be one of ``"uniform"``,
            ``"logit"``, ``"lewis"``, ``"leverage"`` or ``"leverage_online"``.
        x_min, x_max: Limits of the x axis.
        y_min, y_max: Limits of the y axis.
        plot_bands: If true, draw a shaded band around every median line.
        font_size: Default font size for the figure.
        font_size_title: Font size of the title.

    Raises:
        KeyError: If a method or the dataset name is not among the known keys.
    """
    results_df = get_results_df(dataset, p, methods)

    # use TeX for typesetting
    plt.rcParams["text.usetex"] = True
    plt.rc("font", size=font_size)

    fig, ax = plt.subplots()

    # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap is available in old and new releases alike.
    colormap = plt.get_cmap("tab20")
    colors = {
        "uniform": colormap(0),
        "logit": colormap(2),
        "lewis": colormap(4),
        "leverage": colormap(6),
        "leverage_online": colormap(8),
    }

    labels = {
        "uniform": "Uniform",
        "logit": "Logit",
        "lewis": "Lewis",
        "leverage": f"L{p}S",
        "leverage_online": f"L{p}S-online",
    }

    titles = {
        "covertype": f"Covertype, p={p}",
        "kddcup": f"Kddcup, p={p}",
        "webspam": f"Webspam, p={p}",
    }

    # The band edges are the quartiles pulled towards the median by this factor.
    # NOTE(review): 0.7413 is presumably 1/1.349, the usual factor for turning
    # an interquartile range into a normal-consistent spread — confirm.
    band_scale = 0.7413

    for cur_method in methods:
        cur_results = results_df.loc[results_df["method"] == cur_method]
        ax.plot(
            cur_results["size"],
            cur_results["median_ratio"],
            color=colors[cur_method],
            label=labels[cur_method],
        )
        if plot_bands:
            ax.fill_between(
                cur_results["size"],
                (cur_results["q25"] - cur_results["median_ratio"]) * band_scale + cur_results["median_ratio"],
                (cur_results["q75"] - cur_results["median_ratio"]) * band_scale + cur_results["median_ratio"],
                color=colors[cur_method],
                alpha=0.3
            )

    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    ax.set_xlabel("reduced size")
    ax.set_ylabel("median approximation ratio")

    ax.set_title(titles[dataset.get_name()], fontsize=font_size_title)

    ax.legend(loc="upper right", frameon=True)

    fig.tight_layout()

    # Save through the figure object so the output never depends on which
    # figure pyplot currently regards as active.
    fig.savefig(settings.PLOTS_DIR / f"{dataset.get_name()}_ratio_plot_p_{p}.pdf")

    plt.show()

In [None]:
# Methods that are compared for each value of p.
methods_by_p = {
    1: ["uniform", "logit", "lewis", "leverage"],
    1.5: ["uniform", "leverage"],
    2: ["uniform", "leverage", "leverage_online"],
    5: ["uniform", "leverage"],
}

# One entry per figure: (dataset class, p, x_max, y_max).
# Every plot uses x_min=0 and y_min=1.
plot_specs = [
    (Covertype, 1, 15000, 1.04),
    (KDDCup, 1, 30000, 2.5),
    (Webspam, 1, 15000, 1.15),

    (Covertype, 1.5, 15000, 1.04),
    (KDDCup, 1.5, 30000, 10),
    (Webspam, 1.5, 15000, 2),

    (Covertype, 2, 15000, 1.10),
    (KDDCup, 2, 30000, 3),
    (Webspam, 2, 15000, 2.5),

    (Covertype, 5, 15000, 10),
    # NOTE(review): the other KDDCup plots use x_max=30000 — confirm that
    # 15000 is intended for p=5.
    (KDDCup, 5, 15000, 1_000_000_000),
    (Webspam, 5, 15000, 100_000),
]

for dataset_cls, p, x_max, y_max in plot_specs:
    # A fresh dataset instance and a fresh method list are passed to every call.
    make_plot(dataset_cls(), p=p, methods=list(methods_by_p[p]), x_min=0, x_max=x_max, y_min=1, y_max=y_max)