## $\mathcal{I}_S$ between descriptors and embeddings

In [None]:
%cd ..

In [None]:

import os
import numpy as np
from utils.utils_notebook import get_MI_df, plot_cmap, LATEX_FIG_PATH, prerpocess_emb_name
import matplotlib.pyplot as plt
import seaborn as sns

# Dataset and result runs whose mutual-information tables are loaded.
DATASET = "ZINC"
results_dir_list = ["run_4"]
# Column later used to select one experimental setting when plotting.
COLUMS_SPLIT = "cond_modes"
# Mathtext symbol of the normalised metric, reused in axis labels.
METRIC = r"$\overline{\mathcal{I}_S}$"

df = get_MI_df(DATASET, results_dir_list)

# Normalise embedding names in both the source (X) and target (Y) columns, in place.
for _col in ("X", "Y"):
    df[_col] = df[_col].apply(prerpocess_emb_name)

In [None]:

def plot_barplot_mi(keys, df, COLUMS_SPLIT, split_value=4,
                    highlight=("GraphMVP", "3D-Infomax")):
    """Plot, per model, the mean normalised MI to and from a set of reference models.

    Rows of ``df`` are kept when ``df[COLUMS_SPLIT] == split_value``, the target
    ``Y`` is in ``keys`` and the source ``X`` is not. They are then averaged per
    source model ``X``. Two horizontal bar plots are drawn side by side:
    ``I(X->Y)/dim`` on the left and ``I(Y->X)/dim`` on the right, each sorted
    ascending. Models in ``highlight`` get saturated colours; all others get
    desaturated ones.

    Args:
        keys: names of the reference models (matched against ``Y``, excluded
            from ``X``).
        df: MI results with at least the columns ``X``, ``Y``, ``I(X->Y)/dim``,
            ``I(Y->X)/dim`` and ``COLUMS_SPLIT``.
        COLUMS_SPLIT: column used to select the experimental setting.
        split_value: value of ``COLUMS_SPLIT`` to keep (defaults to 4).
        highlight: model names drawn with saturated colours.

    Returns:
        The ``(fig, axes)`` pair of the created figure.
    """
    highlight = frozenset(highlight)
    selected = df[
        (df[COLUMS_SPLIT] == split_value)
        & df.Y.isin(keys)
        & ~df.X.isin(keys)
    ]
    # numeric_only=True: a whole-frame mean raises on string columns (e.g. Y)
    # with pandas >= 2.0.
    df_plot = selected.groupby("X").mean(numeric_only=True).reset_index()

    n_models = df_plot.X.nunique()
    palette_no_highlight = sns.color_palette("hls", n_models, desat=0.15)
    palette_highlight = sns.color_palette("hls", n_models, desat=1)

    fig, axes = plt.subplots(1, 2, figsize=(3.9, 5))
    panels = (
        ("I(X->Y)/dim", "Ability to predict denoising \n3D models"),
        ("I(Y->X)/dim", "Ability to be predicted \nby denoising 3D models"),
    )
    for ax, (column, xlabel) in zip(axes, panels):
        data = df_plot.sort_values(column)
        # Colour follows the model's rank in this panel's ordering;
        # highlighted models take the saturated palette.
        cmap = {
            model: (palette_highlight if model in highlight else palette_no_highlight)[i]
            for i, model in enumerate(data.X.unique())
        }
        sns.barplot(
            data=data,
            y="X",
            x=column,
            hue="X",
            ax=ax,
            legend=False,
            palette=cmap,
        )
        ax.set_ylabel("")
        ax.set_xlabel(xlabel)
        ax.set_xticklabels([])

    fig.tight_layout()
    return fig, axes

In [None]:
# Reference (3D denoising) models against which every other model is compared.
keys = ["3D-denoising", "FRAD_QM9"]
plot_barplot_mi(keys, df, COLUMS_SPLIT)

# Export the figure just drawn (the current pyplot figure) for the LaTeX document.
out_path = f"{LATEX_FIG_PATH}/molecule/denoising_3D_MI.pdf"
plt.savefig(out_path, bbox_inches="tight", format="pdf")

In [None]:
from scipy.cluster.hierarchy import linkage
%matplotlib inline
# Per-model median normalised MI in both directions (model Z vs. every other model U).
df_plot = df

fig, axes = plt.subplots(1, 2, figsize=(2.8, 5.5), sharey=True)

# Both panels share the y-axis; order models by their median outgoing MI.
# Aggregating only the needed column avoids pandas >= 2.0 raising on
# non-numeric columns (e.g. "Y") in a whole-frame median.
order = df_plot.groupby("X")["I(X->Y)/dim"].median().sort_values().index

# Bar plots of the median outgoing (left) and incoming (right) MI.
for ax, column in zip(axes, ("I(X->Y)/dim", "I(Y->X)/dim")):
    sns.barplot(
        data=df_plot,
        y="X",
        x=column,
        hue="X",
        ax=ax,
        palette="coolwarm",
        order=order,
        hue_order=order,
        capsize=.2,
        err_kws={"linewidth": 0.5},
        estimator=np.median,
    )

axes[0].set_ylabel("Models Z")

# METRIC holds the same \overline{\mathcal{I}_S} symbol with balanced braces,
# so mathtext can parse the label.
axes[0].set_xlabel(METRIC + r"$(Z\rightarrow U)$")
axes[1].set_xlabel(METRIC + r"$(U\rightarrow Z)$")

# Force a draw so the categorical tick labels are populated before restyling them.
plt.draw()
# Put 3D models in bold red (the shared y-axis labels live on axes[0]).
highlight_3d = {"FRAD_QM9", "3D-denoising"}
for label in axes[0].get_yticklabels():
    if label.get_text() in highlight_3d:
        label.set_weight("bold")
        label.set_color("red")

fig.savefig(
    f"{LATEX_FIG_PATH}/molecule/barplot_MI-3D.pdf",
    format="pdf",
    bbox_inches="tight",
)