In [None]:
import pandas as pd
from efficient_probit_regression.datasets import BaseDataset, Covertype
from efficient_probit_regression import settings

import plotly.graph_objects as go

In [None]:
def make_report(dataset: BaseDataset, methods=("uniform", "leverage", "leverage_online")):
    """Plot MMD against sample size for several sampling methods on one dataset.

    For each name in ``methods`` the CSV
    ``settings.RESULTS_DIR_BAYES / f"{dataset.get_name()}_mmd_{method}.csv"`` is
    read; it must contain at least the columns ``size`` and ``mmd``. Every row is
    drawn as a marker, and the per-size median of ``mmd`` is drawn as a line.
    Both traces of a method share a legend group so they toggle together.
    The figure is shown with ``fig.show()``; nothing is returned.

    Parameters
    ----------
    dataset : BaseDataset
        Dataset whose ``get_name()`` determines the result file names and the
        figure title.
    methods : iterable of str
        Method names to include. Defaults to
        ``("uniform", "leverage", "leverage_online")``.

    Raises
    ------
    ValueError
        If ``methods`` is empty.
    FileNotFoundError
        If the result CSV for one of the methods does not exist.
    """
    # Copy into a fresh list: avoids a shared mutable default and lets callers
    # pass any iterable (list, tuple, generator) that is consumed twice below.
    methods = list(methods)
    if not methods:
        raise ValueError("make_report needs at least one method")

    df_list = []
    for cur_method in methods:
        cur_df = pd.read_csv(settings.RESULTS_DIR_BAYES / f"{dataset.get_name()}_mmd_{cur_method}.csv")
        cur_df["method"] = cur_method
        df_list.append(cur_df)

    df = pd.concat(df_list, ignore_index=True)

    # Aggregate only the "mmd" column: since pandas 2.0 groupby reductions use
    # numeric_only=False, so any extra non-numeric column in the CSVs would
    # make a whole-frame median() raise a TypeError.
    # groupby sorts by its keys, so each method's sizes come out ascending,
    # which the line traces below rely on.
    df_median = df.groupby(["size", "method"], as_index=False)["mmd"].median()

    fig = go.Figure()

    for cur_method in methods:
        points = df.loc[df["method"] == cur_method]
        medians = df_median.loc[df_median["method"] == cur_method]

        fig.add_trace(go.Scatter(
            x=points["size"],
            y=points["mmd"],
            name=cur_method + "_points",
            mode="markers",
            legendgroup=cur_method,
        ))

        fig.add_trace(go.Scatter(
            x=medians["size"],
            y=medians["mmd"],
            name=cur_method + "_median",
            mode="lines",
            legendgroup=cur_method,
        ))

    fig.update_xaxes(title_text="size")
    fig.update_yaxes(title_text="mmd")
    fig.update_layout(title=dataset.get_name())

    fig.show()

In [None]:
# Compare only the "uniform" and "leverage" methods on the Covertype dataset.
make_report(
    Covertype(),
    methods=["uniform", "leverage"],
)