In [None]:
from efficient_probit_regression import settings
from efficient_probit_regression.datasets import BaseDataset, Covertype, KDDCup, Webspam

import pandas as pd
import plotly.graph_objects as go

In [None]:
def _read_method_results(dataset: BaseDataset, method, columns):
    """Read ``<RESULTS_DIR>/<dataset name>_<method>.csv`` and keep only ``columns``.

    ``DataFrame.filter(items=...)`` silently ignores names that are not present
    in the file, so a missing column is dropped rather than raising.
    """
    path = settings.RESULTS_DIR / f"{dataset.get_name()}_{method}.csv"
    return pd.read_csv(path).filter(items=columns)


def get_results_df(dataset: BaseDataset, methods):
    """Return the median ``ratio`` per ``size`` for every method.

    Args:
        dataset: dataset whose ``get_name()`` selects the result files.
        methods: iterable of method names; each one maps to a CSV file
            ``<dataset name>_<method>.csv`` in ``settings.RESULTS_DIR``.

    Returns:
        A single DataFrame with columns ``size``, ``ratio`` (median over all
        rows sharing that size) and ``method``, one row per (method, size).
        Rows are ordered by method, then by size (groupby sorts its keys).

    Raises:
        ValueError: from ``pd.concat`` when ``methods`` is empty.
    """
    frames = [
        _read_method_results(dataset, method, ["ratio", "size"])
        .groupby(["size"], as_index=False)
        .median()
        .assign(method=method)
        for method in methods
    ]
    return pd.concat(frames, ignore_index=True)


def get_results_df_raw(dataset: BaseDataset, methods):
    """Return every individual result row for every method, unaggregated.

    Args:
        dataset: dataset whose ``get_name()`` selects the result files.
        methods: iterable of method names; each one maps to a CSV file
            ``<dataset name>_<method>.csv`` in ``settings.RESULTS_DIR``.

    Returns:
        A single DataFrame with columns ``run``, ``ratio``, ``size`` (those
        present in the file) plus ``method``, rows concatenated in the order
        of ``methods``.

    Raises:
        ValueError: from ``pd.concat`` when ``methods`` is empty.
    """
    frames = [
        _read_method_results(dataset, method, ["run", "ratio", "size"])
        .assign(method=method)
        for method in methods
    ]
    return pd.concat(frames, ignore_index=True)

In [None]:
def make_report(dataset: BaseDataset, methods):
    """Plot ``ratio`` against ``size`` for each method and display the figure.

    For every method two traces are drawn: the individual result rows as
    markers (``<method>_raw``) and the per-size median as a line
    (``<method>_median``). Both traces of one method share a colour and a
    legend group, so they are visually paired and toggle together.

    Args:
        dataset: dataset whose results are plotted; ``get_name()`` is used as
            the figure title and (via the loaders) to locate the CSV files.
        methods: method names whose result files are read and plotted.
    """
    # plotly is already a dependency of this file; only the palette is needed.
    from plotly.colors import qualitative

    df_median = get_results_df(dataset, methods)
    df_raw = get_results_df_raw(dataset, methods)

    palette = qualitative.Plotly
    fig = go.Figure()

    for i, method in enumerate(methods):
        # Without an explicit colour Plotly advances its colorway per trace,
        # so the raw points and the median line of one method would differ.
        color = palette[i % len(palette)]
        raw = df_raw[df_raw["method"] == method]
        median = df_median[df_median["method"] == method]

        fig.add_trace(go.Scatter(
            x=raw["size"],
            y=raw["ratio"],
            name=method + "_raw",
            mode="markers",
            legendgroup=method,
            marker=dict(color=color),
        ))
        fig.add_trace(go.Scatter(
            x=median["size"],
            y=median["ratio"],
            name=method + "_median",
            mode="lines",
            legendgroup=method,
            line=dict(color=color),
        ))

    fig.update_xaxes(title_text="size")
    fig.update_yaxes(title_text="ratio")
    fig.update_layout(title=dataset.get_name())

    fig.show()

In [None]:
# Covertype: raw and median ratio-vs-size traces for each method.
make_report(
    Covertype(),
    methods=["uniform", "leverage", "leverage_online"],
)

In [None]:
# KDDCup: raw and median ratio-vs-size traces for each method.
make_report(
    KDDCup(),
    methods=["uniform", "leverage", "leverage_online"],
)

In [None]:
# Webspam: raw and median ratio-vs-size traces for each method.
make_report(
    Webspam(),
    methods=["uniform", "leverage", "leverage_online"],
)