In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D

from scoresbibm.utils.data_utils import query, get_summary_df, load_model
from scoresbibm.utils.plot import plot_metric_by_num_simulations, use_style, multi_plot

import logging

# Disable the 'matplotlib.font_manager' logger so its messages do not clutter
# the notebook output.
_font_manager_logger = logging.getLogger('matplotlib.font_manager')
_font_manager_logger.disabled = True


In [None]:
def filter_guidance_method(x, val="repaint"):
    """Keep the rows whose posterior ``sampling_method`` equals ``val``.

    Args:
        x: Results frame with a ``cfg`` column holding a string that ``eval``
            turns into a nested mapping with ``["method"]["posterior"]`` keys
            (presumably the ``repr`` of the run config).
        val: Sampling method to keep, e.g. ``"repaint"`` or
            ``"generalized_guidance"``.

    Returns:
        The matching rows of ``x`` with all columns kept. An empty input gives
        an empty frame with the same columns, so further filters can be chained.
    """
    # NOTE(review): ``eval`` executes whatever is stored in the config string;
    # only use on trusted result files (``ast.literal_eval`` would be safer if
    # the configs contain only literals -- confirm before switching).
    mask = x["cfg"].apply(
        lambda cfg: eval(cfg)["method"]["posterior"]["sampling_method"] == val
    )
    # ``apply`` on an empty Series returns an object-dtype Series, which pandas
    # interprets as column labels (dropping every column) instead of a boolean
    # mask; forcing bool dtype keeps empty results well-formed.
    return x[mask.astype(bool)]

def filter_resampling_steps(x, num=0):
    """Keep the rows whose posterior ``resampling_steps`` equals ``num``.

    Args:
        x: Results frame with a ``cfg`` column holding a string that ``eval``
            turns into a nested mapping with ``["method"]["posterior"]`` keys
            (presumably the ``repr`` of the run config).
        num: Number of resampling steps to keep.

    Returns:
        The matching rows of ``x`` with all columns kept. An empty input gives
        an empty frame with the same columns, so further filters can be chained.
    """
    # NOTE(review): ``eval`` executes whatever is stored in the config string;
    # only use on trusted result files (``ast.literal_eval`` would be safer if
    # the configs contain only literals -- confirm before switching).
    mask = x["cfg"].apply(
        lambda cfg: eval(cfg)["method"]["posterior"]["resampling_steps"] == num
    )
    # ``apply`` on an empty Series returns an object-dtype Series, which pandas
    # interprets as column labels (dropping every column) instead of a boolean
    # mask; forcing bool dtype keeps empty results well-formed.
    return x[mask.astype(bool)]

def filter_ids(x, ids):
    """Keep the rows whose ``model_id`` is one of ``ids``.

    Args:
        x: Results frame with a ``model_id`` column.
        ids: Any list-like of model ids (list, set, tuple, generator, ...).

    Returns:
        The matching rows of ``x`` in their original order with all columns
        kept. An empty input gives an empty frame with the same columns.
    """
    # ``isin`` materialises ``ids`` once (a generator is no longer exhausted
    # after the first row), uses hashing instead of a per-row list scan, and
    # always yields a bool mask -- even for an empty frame.
    return x[x["model_id"].isin(ids)]

In [None]:
import pandas as pd
import seaborn as sns

In [None]:
# Model ids of the runs whose `method_sde_name` is "vpsde" / "vesde" in the joint
# all-conditional benchmark (metric="none": presumably no metric filter -- see `query`).
_sde_model_ids = {
    sde_name: query(
        "../../results_final/main_benchmark_all_cond_joint2",
        metric="none",
        method_sde_name=sde_name,
    )["model_id"].tolist()
    for sde_name in ("vpsde", "vesde")
}
vpsde_id = _sde_model_ids["vpsde"]
vesde_id = _sde_model_ids["vesde"]

In [None]:
# C2ST results of the joint score transformer. No `task` filter is passed here,
# presumably so results for all tasks are returned -- confirm against `query`.
df_guidance = query(
    "../../results_final/main_benchmark_all_cond_joint2",
    metric="c2st",
    method="score_transformer_joint",
)

In [None]:
# Figure: C2ST of the joint score transformer under four guidance variants.
# Rows: VESDE / VPSDE model ids; columns: tasks.
RESULTS_DIR = "../../results_final/main_benchmark_all_cond_joint2"

# Guidance variants in legend order:
# (sampling_method, resampling_steps, suffix appended to `method`, legend label).
GUIDANCE_VARIANTS = [
    ("repaint", 0, "_repaint0", "Repaint (r=0)"),
    ("repaint", 5, "_repaint5", "Repaint (r=5)"),
    ("generalized_guidance", 0, "_generalized_guidance0", "GGuidance (r=0)"),
    ("generalized_guidance", 5, "_generalized_guidance5", "GGuidance (r=5)"),
]


def _stack_guidance_variants(df_task, model_ids):
    """Return the rows of every guidance variant for ``model_ids``, stacked.

    Each variant's ``method`` value gets its suffix appended (e.g.
    ``score_transformer_joint_repaint0``) so it can serve as the hue key.
    ``.assign`` builds a new frame instead of writing into a filtered slice,
    which would raise pandas' SettingWithCopyWarning.
    """
    parts = []
    for sampling_method, resampling_steps, suffix, _label in GUIDANCE_VARIANTS:
        part = filter_ids(
            filter_guidance_method(
                filter_resampling_steps(df_task, num=resampling_steps),
                val=sampling_method,
            ),
            model_ids,
        )
        parts.append(part.assign(method=part["method"] + suffix))
    return pd.concat(parts)


tasks = ["tree_all_cond", "marcov_chain_all_cond", "two_moons_all_cond", "slcp_all_cond"]
color_map = {
    "score_transformer_joint_repaint0": "#1e81b0",
    "score_transformer_joint_repaint5": "#76b5c5",
    "score_transformer_joint_generalized_guidance0": "#e28743",
    "score_transformer_joint_generalized_guidance5": "#eab676",
}
# (row label, model ids) for each subplot row.
sde_rows = [("VESDE", vesde_id), ("VPSDE", vpsde_id)]

# Query each task once and reuse the frame for both SDE rows. A dedicated name
# is used so the `df_guidance` frame from the previous cell is not overwritten.
c2st_by_task = {
    task: query(RESULTS_DIR, metric="c2st", method="score_transformer_joint", task=task)
    for task in tasks
}

with use_style("pyloric"):
    fig, axes = plt.subplots(2, 4, figsize=(10, 3.), sharex=True, sharey=True)
    for row, (sde_label, _) in enumerate(sde_rows):
        axes[row, 0].set_yticks([0.5, 1.])
        axes[row, 0].set_yticklabels([0.5, 1.])
        axes[row, 0].set_ylim([0.5, 1.])
        axes[row, 0].set_ylabel(f"{sde_label}\n\nC2ST", x=-1.)

    last_row = len(sde_rows) - 1
    for row, (_, model_ids) in enumerate(sde_rows):
        for col, task in enumerate(tasks):
            ax = axes[row, col]
            sns.pointplot(
                data=_stack_guidance_variants(c2st_by_task[task], model_ids),
                x="num_simulations",
                y="value",
                hue="method",
                # Fixed hue order keeps colours/positions stable even when a
                # variant is missing for some task.
                hue_order=list(color_map),
                palette=color_map,
                alpha=0.8,
                ax=ax,
                legend=False,
            )
            if row == last_row:
                # Assumes exactly three simulation budgets per task (1e3, 1e4, 1e5).
                ax.set_xticklabels([r"$10^3$", r"$10^4$", r"$10^5$"])
                ax.set_xlabel("Number of simulations")

    # Explicit proxy handles: `fig.legend(labels)` alone pairs the labels with the
    # first artists found in the axes (for pointplots these likely include
    # error-bar lines), so entries could show the wrong colour.
    legend_handles = [
        Line2D([], [], color=color_map["score_transformer_joint" + suffix], marker="o", label=label)
        for _, _, suffix, label in GUIDANCE_VARIANTS
    ]
    fig.legend(handles=legend_handles, loc="center", bbox_to_anchor=(0.5, 1.), ncol=4)
    fig.savefig("guidance.svg", bbox_inches="tight")
    plt.show()