In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import seaborn as sns
from src.plots.latex import set_size, update_rcParams, HUE_ORDER

In [None]:
# Results labelled "stable" (from publicallstable.parquet), plus derived
# accuracy-difference and error-rate columns.
df1 = pd.read_parquet("../reports/publicallstable.parquet")
print(len(df1))
df1 = df1.assign(
    acc_diff=df1["Acc1"] - df1["Acc2"],
    err1=1 - df1["Acc1"],
    **{"err-mean": 1 - df1["Acc-mean"]},  # hyphenated name needs dict unpacking
    type="stable",
)
df1.head()

In [None]:
# Results labelled "baseline" (from publicsplit.parquet), plus derived
# accuracy-difference and error-rate columns.
df2 = pd.read_parquet("../reports/publicsplit.parquet")
print(len(df2))
df2 = df2.assign(
    acc_diff=df2["Acc1"] - df2["Acc2"],
    err1=1 - df2["Acc1"],
    **{"err-mean": 1 - df2["Acc-mean"]},  # hyphenated name needs dict unpacking
    type="baseline",
)
df2.head()

In [None]:
# Stack the stable and baseline results row-wise with a fresh 0..n-1 index.
df = pd.concat([df1, df2], ignore_index=True)
df.head()

In [None]:
# Bar plot of every metric value, faceted by model (rows) and metric (columns).
sns.catplot(
    kind="bar",
    data=df,
    row="Model",
    col="Metric",
    x="type",
    y="Value",
    hue="Dataset",
    sharey=False,
)

In [None]:
# Mean "Acc-mean" per training type and dataset, with standard-deviation error bars.
sns.catplot(
    kind="bar",
    data=df,
    row="Model",
    x="type",
    y="Acc-mean",
    hue="Dataset",
    ci="sd",
    sharey=False,
)

In [None]:
from pathlib import Path
import json
from omegaconf import OmegaConf

In [None]:
# CHANGE: fill in the runs dict with the Hydra multirun directories to compare
# the performance of the "stable" model and the baseline.
runs = {
    "Public Split": [Path("")],
    "PublicAllStable": [
        Path(""),  # Adam
        Path(""),
        Path(""),
    ],
}

# One record per entry of each experiment's predictions/evals.json, tagged with
# the model/dataset names from that experiment's Hydra config.
records = []
for runtype, rundirs in runs.items():
    for rundir in rundirs:
        # Path("") is the unfilled placeholder above. It resolves to the current
        # working directory, which is not a multirun directory.
        if rundir == Path(""):
            raise ValueError(
                f"runs[{runtype!r}] still contains the placeholder Path(''); "
                "fill in the multirun directories before running this cell."
            )
        # Only sub-directories that contain a Hydra config are experiments. Sort
        # them so the row order does not depend on filesystem iteration order.
        experiments = sorted(
            p for p in rundir.iterdir()
            if p.is_dir() and (p / ".hydra" / "config.yaml").is_file()
        )
        for experiment in experiments:
            cfg = OmegaConf.load(experiment / ".hydra" / "config.yaml")
            with (experiment / "predictions" / "evals.json").open("r") as f:
                evals = json.load(f)
            for e in evals:
                records.append(
                    {
                        "Acc": e["test_acc"],
                        "Model": cfg.model.name,
                        "RunType": runtype,
                        "Dataset": cfg.dataset.name,
                    }
                )

# Explicit columns keep the schema even when no records were collected.
accs = pd.DataFrame(records, columns=["Acc", "Model", "RunType", "Dataset"])
accs.head()


In [None]:
# Distribution of per-evaluation test accuracy for each run type.
sns.catplot(
    kind="box",
    data=accs,
    row="Model",
    col="Dataset",
    x="RunType",
    y="Acc",
    sharey=False,
)

In [None]:
# Facet column order for the metrics; currently unused (see the commented-out
# col_order argument below).
col_order = ["PI", "NormPI", "True PI", "False PI", "MAE", "SymKL"]

# GAT2017 metric values as box plots: first with "Acc-std" on the x axis, then
# with "Acc-mean". The two plots previously came from copy-pasted calls.
for x_col in ["Acc-std", "Acc-mean"]:
    sns.catplot(
        data=df[df.Model == "GAT2017"],
        hue="type",
        x=x_col,
        y="Value",
        row="Dataset",
        col="Metric",
        sharey=False,
        sharex=False,
        kind="box",
        hue_order=["baseline", "stable"],
        # col_order=col_order,
    )




In [None]:
# Mean "PI" value of GAT2017 on WikiCS, split by training type.
gat_wikics_pi = (
    df["Model"].eq("GAT2017")
    & df["Metric"].eq("PI")
    & df["Dataset"].eq("WikiCS")
)
df.loc[gat_wikics_pi].groupby("type")["Value"].mean()

In [None]:
# Scatter of mean vs. std of accuracy, coloured by training type.
sns.relplot(data=df, x="Acc-mean", y="Acc-std", hue="type")



## Paper plots: disagreement and error rate, stable vs. baseline

In [None]:
# Main-paper figures. For each model, save a bar plot of the "PI" metric
# (stable1_<model>.pdf) and a bar plot of the test error rate
# (stable2_<model>.pdf), both comparing the stable runs with the baseline.

# matplotlib 3.6 renamed the built-in "seaborn" style to "seaborn-v0_8", and 3.8
# removed the old name. Use whichever name this installation provides.
SEABORN_STYLE = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

# x-axis category order for each data source; both map onto the same tick labels.
TYPE_ORDER = ["stable", "baseline"]  # values of df["type"]
RUNTYPE_ORDER = ["PublicAllStable", "Public Split"]  # keys of the runs dict / accs["RunType"]
XTICK_LABELS = ["Stable", "Baseline"]

models = ["GCN2017", "GAT2017"]
for model in models:
    with plt.style.context(SEABORN_STYLE):
        with update_rcParams({"axes.grid.which": "both", "lines.linewidth": 1, "lines.markersize": 5}):
            nrows, ncols = 1, 1
            width, height = set_size(fraction=0.5)

            # Figure 1: mean "PI" value per dataset and training type.
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, height))
            pf = df.loc[(df["Metric"] == "PI") & (df["Model"] == model)]
            sns.barplot(
                data=pf,
                x="type",
                y="Value",
                hue="Dataset",
                hue_order=HUE_ORDER,
                order=TYPE_ORDER,
                ci=None,
                ax=axes,
            )
            axes.set_xticklabels(XTICK_LABELS)
            axes.set_ylabel("Disagreement $d$")
            axes.set_xlabel("Model")
            # Legend is placed below/right of the axes; the error-rate figure
            # below removes its own legend.
            axes.legend(loc="lower right", ncol=4, bbox_to_anchor=(1.85, -0.66))
            fig.savefig(f"../reports/stable1_{model}.pdf", bbox_inches="tight")

            # Figure 2: mean test error rate (1 - Acc) per dataset and run type.
            # Computed on a local copy instead of writing an "err" column into
            # the global accs frame.
            model_accs = accs.loc[accs.Model == model].assign(err=lambda d: 1 - d["Acc"])
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, height))
            sns.barplot(
                data=model_accs,
                x="RunType",
                y="err",
                hue="Dataset",
                order=RUNTYPE_ORDER,
                hue_order=HUE_ORDER,
                ci=None,
                ax=axes,
            )
            axes.get_legend().remove()
            axes.set_xticklabels(XTICK_LABELS)
            axes.set_ylabel("Error Rate")
            axes.set_xlabel("Model")
            fig.savefig(f"../reports/stable2_{model}.pdf", bbox_inches="tight")

In [None]:
# For each model, save a bar plot of the "False PI" metric
# (stable1_<model>_fpi.pdf), comparing the stable runs with the baseline.

# matplotlib 3.6 renamed the built-in "seaborn" style to "seaborn-v0_8", and 3.8
# removed the old name. Use whichever name this installation provides.
SEABORN_STYLE = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

models = ["GCN2017", "GAT2017"]
for model in models:
    with plt.style.context(SEABORN_STYLE):
        with update_rcParams({"axes.grid.which": "both", "lines.linewidth": 1, "lines.markersize": 5}):
            width, height = set_size(fraction=0.5)
            fig, axes = plt.subplots(figsize=(width, height))

            pf = df.loc[(df["Metric"] == "False PI") & (df["Model"] == model)]
            sns.barplot(
                data=pf,
                x="type",
                y="Value",
                hue="Dataset",
                hue_order=HUE_ORDER,
                order=["stable", "baseline"],
                ci=None,
                ax=axes,
            )
            # Tick labels follow the `order` argument above.
            axes.set_xticklabels(["Stable", "Baseline"])
            # This figure shows the "False PI" metric, not overall disagreement.
            # Label it the same way the appendix figures do.
            axes.set_ylabel("False Dis. $d_{False}$")
            axes.set_xlabel("Model")
            axes.legend(loc="lower right", ncol=4, bbox_to_anchor=(1.85, -0.66))
            fig.savefig(f"../reports/stable1_{model}_fpi.pdf", bbox_inches="tight")
            
    

In [None]:
# Appendix figures: one bar plot per (error-bar mode, model, metric) comparing
# stable runs with the baseline across datasets, saved under ../reports/appendix/.

# matplotlib 3.6 renamed the built-in "seaborn" style to "seaborn-v0_8", and 3.8
# removed the old name. Use whichever name this installation provides.
SEABORN_STYLE = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

APPENDIX_DIR = Path("../reports/appendix")
APPENDIX_DIR.mkdir(parents=True, exist_ok=True)  # savefig does not create directories

# y-axis label per value of df["Metric"]. Metrics missing here fall back to their raw name.
metric_to_name = {
    "PI": "Disagreement $d$",
    "NormPI": "Norm. Dis. $d_{Norm}$",
    "False PI": "False Dis. $d_{False}$",
    "True PI": "True Dis. $d_{True}$",
    "MAE": "MAE",
    "SymKL": "Symmetric KL-Div",
}

# ci="sd" draws standard-deviation error bars and ci=None draws none;
# ci_txt tags the output file name accordingly.
for ci, ci_txt in zip(["sd", None], ["sd", "nosd"]):
    for model in df["Model"].unique():
        for metric in df["Metric"].unique():
            with plt.style.context(SEABORN_STYLE):
                with update_rcParams({"axes.grid.which": "both", "lines.linewidth": 1, "lines.markersize": 5}):
                    width, height = set_size(fraction=0.5)
                    fig, axes = plt.subplots(figsize=(width, height))

                    pf = df.loc[(df["Metric"] == metric) & (df["Model"] == model)]
                    sns.barplot(
                        data=pf,
                        x="type",
                        y="Value",
                        hue="Dataset",
                        hue_order=HUE_ORDER,
                        order=["stable", "baseline"],
                        ci=ci,
                        ax=axes,
                    )
                    # Tick labels follow the `order` argument above.
                    axes.set_xticklabels(["Stable", "Baseline"])
                    axes.set_ylabel(metric_to_name.get(metric, metric))
                    axes.set_xlabel("Model")
                    axes.legend(loc="lower right", ncol=3, bbox_to_anchor=(1, -0.7))
                    fig.savefig(
                        APPENDIX_DIR / f"stable1_{model}_{metric}_{ci_txt}.pdf",
                        bbox_inches="tight",
                    )
                    # This loop creates one figure per combination, which easily
                    # exceeds matplotlib's 20-open-figures warning threshold.
                    # Release each figure once it is saved.
                    plt.close(fig)
                    
            