In [None]:
%cd ..
import os
import numpy as np
import pandas as pd

from utils.utils_notebook import get_MI_df, get_DTI_rank_df, LATEX_FIG_PATH
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# LaTeX label of the information metric; reused below both as an axis label
# and as a DataFrame column name.
METRIC = r"$\overline{\mathcal{I}_S}$"
# NOTE(review): this silences every warning for the whole kernel session,
# including deprecation warnings emitted by pandas / seaborn.
warnings.filterwarnings("ignore")

# Downstream Eval

In [None]:
# Load the results frame for "ZINC", restricted to the "run_4" results directory.
# NOTE(review): the returned schema is defined in utils.utils_notebook — not visible here.
df = get_MI_df("ZINC", results_dir_list=["run_4"])
# Ranking frame for the clustering_l2_4 metric on KIBA.
# NOTE(review): the cell displays `df`, not the `df_downs` computed on this line.
df_downs = get_DTI_rank_df(df, metric=f"clustering_l2_4",dataset="KIBA",order="ascending")
df


In [None]:
%matplotlib inline


def plot_corr(df, REMOVE_MODELS=None, FIGSIZE=3, title="", DATASET="Kiba"):
    """Draw a 2x2 grid of downstream-rank vs. information-rank scatter plots.

    One panel is drawn per neighbourhood size in ``[1, 2, 4, 8]``. For each
    panel the frame returned by ``get_DTI_rank_df`` is averaged per embedder,
    ``meanrank_metric`` is plotted against ``meanrank_information`` with a
    regression line, and the Pearson (sign-flipped), Spearman and Kendall
    coefficients are written inside the axes.

    Parameters
    ----------
    df
        Forwarded unchanged to ``get_DTI_rank_df``.
    REMOVE_MODELS : iterable of str, optional
        Embedder names dropped before averaging. ``None`` keeps every embedder.
    FIGSIZE : float
        Half of the figure side length; the figure is
        ``(2 * FIGSIZE, 2 * FIGSIZE)`` inches.
    title : str
        Figure-level title.
    DATASET : str
        Forwarded as ``dataset=`` to ``get_DTI_rank_df``.

    Returns
    -------
    None
        The figure is left as the current pyplot figure so the caller can
        ``plt.savefig`` / ``plt.show`` it.
    """
    removed = [] if REMOVE_MODELS is None else REMOVE_MODELS
    fig, axes = plt.subplots(
        2, 2, figsize=(FIGSIZE * 2, FIGSIZE * 2), sharex=True, sharey=False
    )
    axes = axes.flatten()
    key = "clustering_l2_"

    for ax, n_clus in zip(axes, [1, 2, 4, 8]):
        metric_col = f"{key}{n_clus}"
        if "l2" in key:
            df_downs = get_DTI_rank_df(
                df, metric=metric_col, dataset=DATASET, order="ascending"
            )
        else:
            # Not reachable while `key` is fixed above; kept for non-L2 metric keys.
            df_downs = get_DTI_rank_df(df, metric=metric_col)
        df_downs = df_downs[~df_downs.embedder.isin(removed)]
        # numeric_only=True: pandas >= 2.0 raises TypeError when asked to
        # average non-numeric columns instead of silently dropping them.
        df_tmp = df_downs.groupby("embedder").mean(numeric_only=True)

        sns.scatterplot(
            data=df_tmp,
            y="meanrank_metric",
            x="meanrank_information",
            hue="embedder",
            ax=ax,
            legend=False,
            style="embedder",
        )
        sns.regplot(
            data=df_tmp,
            y="meanrank_metric",
            x="meanrank_information",
            ax=ax,
            scatter=False,
            color="blue",
            line_kws={"alpha": 0.2},
        )

        # Correlation coefficients shown inside the panel.
        ranks = df_tmp[["meanrank_metric", "meanrank_information"]]
        corr = ranks.corr("spearman").iloc[0, 1]
        corr_kendall = ranks.corr("kendall").iloc[0, 1]
        # Pearson is computed on the raw columns (not the ranks) and negated.
        corr_p = -df_tmp[[metric_col, "information"]].corr("pearson").iloc[0, 1]
        ax.text(
            0.8,
            0.2,
            f"\nR: {corr_p:.2f}\n" + r"$\varrho_s $" + f": {corr:.2f}\n " + r"$\tau $: " + f"{corr_kendall:.2f}",
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes
        )
        ax.set_title("$n_{neighbors}$ = " + f"{n_clus}")

        ax.set_xlabel("")
        ax.set_ylabel("")

        # Ranks are only meaningful relative to each other: hide the ticks.
        ax.set_xticks([])
        ax.set_yticks([])

    fig.supylabel("Local agreement rank")
    fig.supxlabel(f"{METRIC} rank")
    fig.suptitle(title)


In [None]:
# NOTE(review): this assignment is redundant — the loop below rebinds DATASET immediately.
DATASET="KIBA"
# One 2x2 correlation figure per DTI dataset, each exported to its own PDF
# under LATEX_FIG_PATH/molecule/.
for DATASET in ["KIBA", "BindingDB_Kd", "BindingDB_Ki", "BindingDB_IC50"]:
    plot_corr(df, title=f"{DATASET}", FIGSIZE=2.3, DATASET=DATASET)
    # plot_corr returns nothing; savefig writes the current pyplot figure it just created.
    plt.savefig(f"{LATEX_FIG_PATH}/molecule/DTI_all_res_{DATASET}.pdf", format = "pdf", bbox_inches = 'tight')
    plt.show()

In [None]:
import pandas as pd
import numpy as np
from autorank import autorank

def compute_ranking(df_downs, dataset="Overall", type_metric="reg", n_neighb=-1):
    """Return autorank's mean rank per embedder over the selected clustering metrics.

    Every column of ``df_downs`` whose name contains ``"clustering"`` is
    stacked into a long table with columns ``embedder``, ``dataset``,
    ``target``, ``metric``, ``n_neighb`` and ``type``. The integer after the
    last ``_`` of the column name becomes ``n_neighb``. Columns containing
    ``"l2"`` are tagged ``"reg"`` and stored as ``1 - value``; all other
    columns are tagged ``"cls"`` and stored unchanged.

    Parameters
    ----------
    df_downs : pandas.DataFrame
        Must contain ``embedder``, ``dataset`` and ``target`` columns next to
        the ``clustering*`` metric columns.
    dataset : str
        Name given to the single column of the returned frame.
    type_metric : str
        ``"reg"`` or ``"cls"`` keeps only that metric type; ``"all"`` keeps both.
    n_neighb : int
        Keep only this neighbourhood size; ``-1`` keeps all of them.

    Returns
    -------
    pandas.DataFrame
        The ``meanrank`` column of autorank's ``rankdf``, renamed to ``dataset``.
    """
    id_columns = ["embedder", "dataset", "target"]

    stacked = []
    for name in df_downs.columns:
        if "clustering" not in name:
            continue
        is_regression = "l2" in name
        part = df_downs[id_columns + [name]].rename(columns={name: "metric"})
        part["n_neighb"] = int(name.split("_")[-1])
        part["type"] = "reg" if is_regression else "cls"
        if is_regression:
            part["metric"] = 1 - part["metric"]
        stacked.append(part)
    long_df = pd.concat(stacked)

    if type_metric != "all":
        long_df = long_df.loc[long_df["type"].eq(type_metric)]
    if n_neighb != -1:
        long_df = long_df.loc[long_df["n_neighb"].eq(n_neighb)]

    # One row per (n_neighb, type, dataset, target); one column per embedder.
    wide = long_df.pivot_table(
        values="metric",
        index=["n_neighb", "type", "dataset", "target"],
        columns="embedder",
    )

    ranking = autorank(wide, alpha=0.05, verbose=False, force_mode="nonparametric")
    mean_rank = ranking.rankdf.meanrank.to_frame()
    return mean_rank.rename(columns={"meanrank": dataset})





In [None]:
# Build the per-embedder summary table: one mean-rank column per DTI dataset,
# an "Overall" column ranked over all datasets pooled together, and the
# averaged information rank stored under the METRIC label.
DATASETS = ["BindingDB_IC50","BindingDB_Ki", "BindingDB_Kd","KIBA"]

# Fetch each dataset's ranking frame once; it is reused for the per-dataset
# ranking and for the pooled frame below.
rank_frames = {
    name: get_DTI_rank_df(df, order="ascending", dataset=name) for name in DATASETS
}

df_plot = None
for dataset in DATASETS:
    res = compute_ranking(rank_frames[dataset], dataset)
    df_plot = res if df_plot is None else df_plot.join(res)

# Pooled frame over all datasets -> "Overall" column (compute_ranking's default name).
df_downs = pd.concat([rank_frames[name] for name in DATASETS])
res = compute_ranking(df_downs)
df_plot = df_plot.join(res)

# Select the column before averaging: pandas >= 2.0 raises TypeError when
# groupby().mean() meets non-numeric columns.
info_rank = (
    df_downs.groupby("embedder")["meanrank_information"]
    .mean()
    .rename(METRIC)
    .to_frame()
)
df_plot = df_plot.join(info_rank)
df_plot

In [None]:
# Turn the index into a regular column and expose "embedder" under the name "model".
# NOTE(review): assumes df_plot's index is named "embedder" — confirm.
avg_results = df_plot.reset_index().rename(columns={"embedder":"model"})

def prerpocess_emb_name(x):
    """Shorten a raw embedder identifier into its display name.

    The substitutions are applied in the listed order:
    ``"DenoisingPretrainingPQCMv4"`` -> ``"3D-denoising"``, ``"Chem"`` removed,
    ``"ThreeDInfomax"`` -> ``"3D-Infomax"``, ``"_OGB"`` removed.

    Parameters
    ----------
    x : str
        Raw embedder name.

    Returns
    -------
    str
        The name with every substitution applied.
    """
    substitutions = (
        ("DenoisingPretrainingPQCMv4", "3D-denoising"),
        ("Chem", ""),
        ("ThreeDInfomax", "3D-Infomax"),
        ("_OGB", ""),
    )
    cleaned = x
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    return cleaned

# Replace the raw embedder identifiers with their shorter display names.
avg_results.model = avg_results.model.apply(prerpocess_emb_name)
avg_results

In [None]:
# Display the ranking frame concatenated over all datasets for inspection.
df_downs

In [None]:
# Flat table holding the clustering_l2_<n> score (renamed to "value") next to
# "dataset", "target" and "information". Only n = 1 is included, so the concat
# has a single element and mainly resets the index.
df_downs_all_concat = pd.concat(
    [
        df_downs[[f"clustering_l2_{n}", "dataset", "target", "information"]].rename(columns={f"clustering_l2_{n}":f"value"}) for n in [1]
    ], ignore_index=True
)
df_downs_all_concat

In [None]:
# make mosaic map
%matplotlib inline

def plot_glob_mosaic(avg_results, METRIC=METRIC, MODEL_TO_ANNOATE = set(["3D-Infomax", "InfoGraph", "MolBert", "BertMTR-5M", "GraphCL", "Not-trained"]), MODEL_TO_ANNOATE_left = set(), FIGSIZE=5):
    """Plot information rank against downstream rank as a five-panel mosaic and save it.

    The left half of the figure is a large "Overall" panel; the right half
    holds four small panels (BindingDB_Kd, BindingDB_Ki, BindingDB_IC50, KIBA).
    Each panel shows a regression line, one marker per model and a box with
    the Spearman and Kendall correlations between ``METRIC`` and the panel's
    column. Selected models are labelled with arrows in the "Overall" panel.

    Parameters
    ----------
    avg_results : pandas.DataFrame
        One row per model. Must contain a ``model`` column, a ``METRIC``
        column and one column per panel name listed above.
    METRIC : str
        Name of the column plotted on the y axis; also used in the y label.
    MODEL_TO_ANNOATE : collection of str
        Models labelled with an offset of (9.5, -9.5) points from the marker.
    MODEL_TO_ANNOATE_left : collection of str
        Models labelled with an offset of (-65.5, 10.5) points. A model listed
        in both collections uses the ``MODEL_TO_ANNOATE`` offset.
    FIGSIZE : float
        Figure width in inches; the height is ``FIGSIZE * 1.7 / 3``.

    Returns
    -------
    None
        The figure is written to
        ``{LATEX_FIG_PATH}/molecule/meanrank_DTI_all.pdf``.
    """
    layout = [
        ['A', 'A', 'B', 'C'],
        ['A', 'A', 'D', 'E'],
    ]

    mosaic_to_task_map = {
        'A': 'Overall',
        'B': 'BindingDB_Kd',
        'C': 'BindingDB_Ki',
        'D': 'BindingDB_IC50',
        'E': 'KIBA',
    }

    # Label every mosaic cell with its task name so axes can be looked up by task.
    mosaic_map = [[mosaic_to_task_map[cell] for cell in row] for row in layout]

    fig, ax = plt.subplot_mosaic(
        mosaic_map,
        figsize=(FIGSIZE, FIGSIZE * 1.7/3),
        gridspec_kw={"width_ratios": [1, 1, 1, 1]},
    )

    for metric in mosaic_to_task_map.values():
        is_overall = metric == 'Overall'
        panel = ax[metric]

        sns.regplot(data=avg_results, y=METRIC, x=metric, ax=panel, scatter=False)
        sns.scatterplot(
            data=avg_results,
            y=METRIC,
            x=metric,
            ax=panel,
            legend=False,
            hue='model',
            style='model',
            s=100 if is_overall else 50,
        )

        # Model names are only written in the large panel.
        if is_overall:
            for _, row in avg_results.iterrows():
                if row['model'] in MODEL_TO_ANNOATE:
                    xytext = (9.5, -9.5)
                elif row['model'] in MODEL_TO_ANNOATE_left:
                    xytext = (-65.5, 10.5)
                else:
                    continue
                panel.annotate(
                    row['model'],
                    (row[metric], row[METRIC]),
                    fontsize=10,
                    va='center',
                    ha='left',
                    textcoords='offset points',
                    xytext=xytext,
                    arrowprops=dict(
                        facecolor='black',
                        color='black',
                        arrowstyle='->',
                        connectionstyle='arc3,rad=0.2'
                    )
                )

        # Rank correlations between the information metric and this panel's column.
        pair = avg_results[[METRIC, metric]]
        corr = pair.corr(method="spearman").iloc[0, 1]
        kendall = pair.corr(method="kendall").iloc[0, 1]
        panel.annotate(
            r"$\varrho_s$" + f": {corr:.2f}\n" + r" $\tau$" + f": {kendall:.2f}",
            (0.75, 0.1) if is_overall else (0.55, 0.05),
            xycoords='axes fraction',
            fontsize=12 if is_overall else 8,
            bbox=dict(
                facecolor='white',
                edgecolor='black',
                boxstyle='round,pad=0.5' if is_overall else 'round,pad=0.3',
            ),
        )

        if is_overall:
            panel.set_title(metric, fontweight='bold')
            panel.set_ylabel(METRIC + "  " + r"$(rank\downarrow)$")
            # Raw string: "\d" is not a valid escape sequence in a normal literal.
            panel.set_xlabel(r'Downstream tasks  $(rank\downarrow)$')
        else:
            panel.set_title(metric.replace('Average', '').replace("(", "\n("), fontsize=10)
            panel.set_xlabel('')
            panel.set_ylabel('')
            panel.set_xticks([])
            panel.set_yticks([])

    fig.tight_layout()

    plt.savefig(f"{LATEX_FIG_PATH}/molecule/meanrank_DTI_all.pdf", format = "pdf", bbox_inches = 'tight')


In [None]:

# Render the summary mosaic; the function itself writes
# {LATEX_FIG_PATH}/molecule/meanrank_DTI_all.pdf.
# Models in MODEL_TO_ANNOATE are labelled to the right of their marker,
# those in MODEL_TO_ANNOATE_left to the left.
plot_glob_mosaic(
    avg_results,
    METRIC=METRIC,
    MODEL_TO_ANNOATE = set(["3D-Infomax", "GPT-1.2B", "MolBert", "BertMTR-5M",]),
    MODEL_TO_ANNOATE_left = set(["GraphCL", "Not-trained", "InfoGraph"]),
    FIGSIZE=7.5
)