In [None]:
from efficient_probit_regression import settings
from efficient_probit_regression.datasets import BaseDataset, Covertype, KDDCup, Webspam
import numpy as np

import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

In [None]:
def get_results_df(dataset: BaseDataset, methods, p):
    """Load the per-size median results of several methods into one frame.

    For each entry of ``methods`` the CSV file
    ``<dataset name>_<method>_p_<p>.csv`` is read from
    ``settings.get_results_dir_p(p)``. Only the ``ratio`` and ``size``
    columns are kept, the rows are grouped by ``size`` and reduced with the
    median, and a ``method`` column holding the method name is attached.

    Args:
        dataset: Dataset whose ``get_name()`` forms the file name prefix.
        methods: Iterable of method names, one result file per name.
        p: Value used both to pick the results directory and in the file name.

    Returns:
        A single DataFrame with the per-method frames stacked on top of each
        other under a fresh 0..n-1 index.
    """
    median_frames = []

    for method in methods:
        results_file = settings.get_results_dir_p(p) / (dataset.get_name() + f"_{method}_p_{p}.csv")
        loaded = pd.read_csv(results_file)
        # Drop everything except the two columns the report plots.
        trimmed = loaded.filter(items=["ratio", "size"])
        # One row per size, carrying the median over all rows of that size.
        per_size_median = trimmed.groupby(["size"], as_index=False).median()
        median_frames.append(per_size_median.assign(method=method))

    return pd.concat(median_frames, ignore_index=True)

def get_results_df_raw(dataset: BaseDataset, methods, p):
    """Load the unaggregated results of several methods into one frame.

    For each entry of ``methods`` the CSV file
    ``<dataset name>_<method>_p_<p>.csv`` is read from
    ``settings.get_results_dir_p(p)``. The ``run``, ``ratio`` and ``size``
    columns are kept and a ``method`` column holding the method name is
    attached; no aggregation is performed.

    Args:
        dataset: Dataset whose ``get_name()`` forms the file name prefix.
        methods: Iterable of method names, one result file per name.
        p: Value used both to pick the results directory and in the file name.

    Returns:
        A single DataFrame with the per-method frames stacked on top of each
        other under a fresh 0..n-1 index.
    """

    def load_single_method(method):
        # Read one method's result file and tag its rows with the method name.
        results_file = settings.get_results_dir_p(p) / (dataset.get_name() + f"_{method}_p_{p}.csv")
        selected = pd.read_csv(results_file).filter(items=["run", "ratio", "size"])
        return selected.assign(method=method)

    raw_frames = [load_single_method(method) for method in methods]
    return pd.concat(raw_frames, ignore_index=True)

In [None]:
def make_report(dataset: BaseDataset, methods, p):
    """Show an interactive plot of ``ratio`` against ``size`` for one dataset.

    For every method other than ``"sgd"`` two traces are drawn in the same
    colour: the raw per-run points (``<method>_raw``) and a line through the
    per-size medians (``<method>_median``). If ``"sgd"`` is among the
    methods, it is drawn as a horizontal line (``sgd_median``) spanning the
    full range of sizes found in the raw results.

    Args:
        dataset: Dataset whose results are plotted; its ``get_name()`` is
            also used in the figure title.
        methods: Method names to include. Traces and colours follow the
            order given here; repeated names are plotted once.
        p: Value forwarded to the result loaders and shown in the title.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    df_median = get_results_df(dataset, methods, p)
    df_raw = get_results_df_raw(dataset, methods, p)

    fig = go.Figure()

    palette = px.colors.qualitative.Plotly

    # Keep the caller's order (dict.fromkeys de-duplicates without reordering).
    # Iterating a set of strings instead would change the method -> colour
    # mapping and the legend order from one kernel session to the next,
    # because string hashing is randomised per process.
    plotted_methods = [method for method in dict.fromkeys(methods) if method != "sgd"]

    for color_index, method in enumerate(plotted_methods):
        # Wrap around the palette so a long method list cannot raise IndexError.
        color = palette[color_index % len(palette)]
        raw_rows = df_raw.loc[df_raw["method"] == method]
        median_rows = df_median.loc[df_median["method"] == method]

        fig.add_trace(go.Scatter(
            x=raw_rows["size"],
            y=raw_rows["ratio"],
            name=method + "_raw",
            mode="markers",
            marker_color=color,
        ))
        fig.add_trace(go.Scatter(
            x=median_rows["size"],
            y=median_rows["ratio"],
            name=method + "_median",
            mode="lines",
            marker_color=color,
        ))

    if "sgd" in methods:
        # Only the first median row of sgd is used as the level of the line.
        # NOTE(review): presumably the sgd results contain a single size - confirm.
        median = df_median.loc[df_median["method"] == "sgd"]["ratio"].to_numpy()[0]
        fig.add_trace(go.Scatter(
            x=[np.min(df_raw["size"]), np.max(df_raw["size"])],
            y=[median, median],
            name="sgd_median",
            mode="lines",
        ))

    fig.update_xaxes(title_text="size")
    fig.update_yaxes(title_text="ratio")
    fig.update_layout(title=f"{dataset.get_name()}, p={p}")

    # fig.write_html(f"report_{dataset.get_name()}_p_{p}.html")

    fig.show()

In [None]:
# Render the report for the Covertype dataset with p=1.
make_report(
    Covertype(),
    methods=["uniform", "lewis", "leverage", "logit", "full_qr"],
    p=1,
)

In [None]:
# NOTE: make_report() requires the `p` argument; add it (e.g. p=1) before re-enabling this call.
# make_report(KDDCup(), methods = ["uniform", "leverage", "leverage_online", "sgd"])

In [None]:
# NOTE: make_report() requires the `p` argument; add it (e.g. p=1) before re-enabling this call.
# make_report(Webspam(), methods = ["uniform", "leverage", "leverage_online", "sgd"])