In [None]:
import pandas as pd
from efficient_probit_regression.datasets import BaseDataset, Covertype
from efficient_probit_regression import settings

import plotly.graph_objects as go
import plotly.express as px

In [None]:
def make_report(dataset: BaseDataset, methods=("uniform", "leverage", "leverage_online"), measure_type="mmd", html_path=None):
    """
    Plot the individual values and the per-size median of one measure for several methods.

    For every method this reads
    ``settings.RESULTS_DIR_BAYES / f"{dataset.get_name()}_{measure_type}_{method}.csv"``.
    That file must contain a ``size`` column and a column named after ``measure_type``.
    Every row is drawn as a marker. The median over rows that share the same ``size``
    is drawn as a line in the same color as that method's markers.

    Parameters
    ----------
    dataset : BaseDataset
        Dataset whose ``get_name()`` selects the result files and prefixes the plot title.
    methods : iterable of str
        Method names. Each one selects one CSV file and produces two traces
        (``<method>_points`` and ``<method>_median``).
    measure_type : str
        Measure to plot; "mmd", "norm" or "matrix_norm" get a descriptive axis title.
        Any other name is used as-is for the axis title, provided the CSV files exist.
    html_path : str or path-like, optional
        If given, the figure is additionally written to this path as standalone HTML.

    Raises
    ------
    ValueError
        If ``methods`` is empty.
    """
    methods = list(methods)
    if not methods:
        raise ValueError("methods must contain at least one method name")

    dataset_name = dataset.get_name()

    df_list = []
    for cur_method in methods:
        cur_df = pd.read_csv(settings.RESULTS_DIR_BAYES / f"{dataset_name}_{measure_type}_{cur_method}.csv")
        cur_df["method"] = cur_method
        df_list.append(cur_df)

    df = pd.concat(df_list, ignore_index=True)

    # Aggregate only the plotted column: since pandas 2.0, GroupBy.median() defaults to
    # numeric_only=False and raises TypeError if the CSVs carry any non-numeric column.
    df_median = df.groupby(["size", "method"], as_index=False)[measure_type].median()

    axis_titles = {
        "mmd": "mmd",
        "norm": "mean difference",
        "matrix_norm": "L2 matrix norm",
    }
    # Fall back to the raw measure name so additional measures can be plotted too.
    axis_title = axis_titles.get(measure_type, measure_type)

    palette = px.colors.qualitative.Plotly
    fig = go.Figure()

    for color_index, cur_method in enumerate(methods):
        # Cycle the palette so that more methods than colors does not raise IndexError.
        color = palette[color_index % len(palette)]
        points = df.loc[df["method"] == cur_method]
        medians = df_median.loc[df_median["method"] == cur_method]

        fig.add_trace(go.Scatter(
            x=points["size"],
            y=points[measure_type],
            name=cur_method + "_points",
            mode="markers",
            marker_color=color,
        ))

        # groupby sorts by "size", so the median line is drawn left to right.
        fig.add_trace(go.Scatter(
            x=medians["size"],
            y=medians[measure_type],
            name=cur_method + "_median",
            mode="lines",
            marker_color=color,
            line_color=color,
        ))

    fig.update_xaxes(title_text="size")
    fig.update_yaxes(title_text=axis_title)
    fig.update_layout(title=dataset_name + " - " + axis_title)

    if html_path is not None:
        fig.write_html(html_path)

    # Deliberately returns None: the notebook cells call this as their last expression,
    # and returning the figure would make Jupyter render it a second time.
    fig.show()

In [None]:
# Covertype report for the "mmd" measure, comparing uniform, leverage and leverage_online.
make_report(
    dataset=Covertype(),
    methods=["uniform", "leverage", "leverage_online"],
    measure_type="mmd",
)

In [None]:
# Covertype report for the "norm" measure, comparing uniform, leverage and leverage_online.
make_report(
    dataset=Covertype(),
    methods=["uniform", "leverage", "leverage_online"],
    measure_type="norm",
)

In [None]:
# Covertype report for the "matrix_norm" measure, comparing uniform, leverage and leverage_online.
make_report(
    dataset=Covertype(),
    methods=["uniform", "leverage", "leverage_online"],
    measure_type="matrix_norm",
)