In [None]:
# Move to the parent directory (presumably the repository root) so that the
# `molecule.utils...` import and the relative `results/...` paths used below resolve.
# NOTE(review): not idempotent — re-running this cell moves up one more level.
%cd ..

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import clear_output

from molecule.utils.utils_notebook import get_MI_df, get_ranked_df, LATEX_FIG_PATH, get_DTI_rank_df, process_dataset_name, prerpocess_emb_name

# Experiment selection: which MI results are loaded.
DATASET = "ZINC"
results_dir_list = ["run_4"]
COLUMS_SPLIT = "cond_modes"  # NOTE(review): not used by the cells visible here
# LaTeX label of the information metric, used in axis labels below.
METRIC = r"$\overline{\mathcal{I}_S}$"

# Load the MI table, then rank embedders with the TDC ADMET results stored in
# results/TDC_ADMET_SCAFF.csv: once globally and once split by task type.
df = get_MI_df(DATASET, results_dir_list)
df_downs = get_ranked_df(df, path="results/TDC_ADMET_SCAFF.csv")
df_downs_pre_task_type = get_ranked_df(
    df,
    split_on="task_type",
    path="results/TDC_ADMET_SCAFF.csv",
)
# Hide whatever the loaders printed to the cell output.
clear_output()

# Downstream Evaluation

In [None]:
df_downs = get_ranked_df(df, path="results/TDC_ADMET_SCAFF.csv")
clear_output()
%matplotlib inline

fig, ax= plt.subplots(1,1,figsize=(1.2,3.5))

df_plot = df_downs
df_plot["embedder"] = df_plot.embedder.apply(prerpocess_emb_name)
sns.scatterplot(data=df_plot.groupby("embedder").mean().sort_values("information"), y="embedder", x="global_meanrank_metric", hue="embedder", legend=False)

corr = df_plot[["global_meanrank_metric", "meanrank_information"]].corr().iloc[0,1]
kendall = df_plot[["global_meanrank_metric", "meanrank_information"]].corr("kendall").iloc[0,1]
#plt.text(0.85, 0.15, r"$\varrho$: " +  f"{corr:.2f}\n"+r"$\tau $: "+f"{kendall:.2f}", horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)
ticks = df_plot.groupby("embedder").mean().sort_values("information").index

ax.vlines(df_plot[df_plot.embedder == "Not-trained"].global_meanrank_metric.iloc[0], 0, len(ticks)-1, color="r", linestyle="--", alpha = 0.8,)
ax.hlines(df_plot[df_plot.embedder == "Not-trained"].embedder.iloc[0], df_plot.global_meanrank_metric.min(), df_plot.global_meanrank_metric.max(), color="r", linestyle="--", alpha = .8)

plt.xlabel("ADMET rank")
plt.ylabel("")
plt.draw()

# Set Not-trained in bold


#Put one y tick out of two on each side of the plot
y2 = ax.twinx()
y2.set_yticks(ax.get_yticks())
y2.set_yticklabels(ticks)
y2.set_yticks(ax.get_yticks()[1::2])
y2.set_ylim(ax.get_ylim())


fontsize = 10
for i, tick in enumerate(y2.get_yticklabels()):
    if i == 4:
        tick.set_fontweight("bold")
        tick.set_fontsize(fontsize)
    else:
        tick.set_fontsize(fontsize)
        tick.set_fontweight("normal")
ax.set_yticks(ax.get_yticks()[::2])
for tick in ax.get_yticklabels():
    tick.set_fontsize(fontsize)
    tick.set_fontweight("normal")


plt.draw()
plt.tight_layout()
#plt.yticks(rotation=90)
plt.savefig(f"{LATEX_FIG_PATH}/molecule/mearnak_detailed_global_scaff.pdf", format = "pdf", bbox_inches = 'tight')

In [None]:
# One row per model: mean ADMET rank and mean information rank, renamed for
# plotting. Names are cleaned before grouping (as for the plot above), and only
# the two numeric columns are averaged: a frame-wide .mean() would also try to
# average string columns (FutureWarning in pandas 1.5, TypeError in pandas 2).
avg_results = (
    df_downs.assign(embedder=df_downs.embedder.apply(prerpocess_emb_name))
    .groupby("embedder")[["global_meanrank_metric", "meanrank_information"]]
    .mean()
    .reset_index()
    .rename(
        columns={
            "global_meanrank_metric": "ADMET",
            "meanrank_information": METRIC,
            "embedder": "model",
        }
    )
)

clear_output()

In [None]:
# make mosaic map
%matplotlib inline

def plot_glob_mosaic(
        avg_results,
        METRIC=METRIC,
        MODEL_TO_ANNOATE = set(["3D-Infomax", "MolBert", "BertMTR-5M", "GraphCL", "Not-trained"]),
        MODEL_TO_ANNOATE_left = set(),
        FIGSIZE=5,
        fontsize=8
):
    mosaic_map = [
        ['A',]
      ]

    mosaic_to_task_map = {
        'A': 'ADMET',
    }

    mosaic_map = np.array(mosaic_map)
    mosaic_map = np.vectorize(lambda x: mosaic_to_task_map[x])(mosaic_map)


    fig, ax = plt.subplot_mosaic(mosaic_map, figsize=(FIGSIZE, FIGSIZE*3/4))
    METRICS = mosaic_to_task_map.values()
    for i, metric in enumerate(METRICS):
        if metric == "ADMET":
            s = 100
        else:
            s = 50
        ax[metric] = sns.regplot(data=avg_results, y=METRIC, x=metric, ax=ax[metric], scatter=False)
        ax[metric] = sns.scatterplot(data=avg_results.sort_values(METRIC), y=METRIC, x=metric, ax=ax[metric], legend=False, hue = 'model', style='model', s=s)


        # annotate model names
        if metric == 'ADMET':
            for i, row in avg_results.iterrows():
                xy = (row[metric], row[METRIC])
                if row['model'] in MODEL_TO_ANNOATE:
                    xytext = (20.5,4)
                    ax[metric].annotate(
                        row['model'],
                        xy,
                        fontsize=fontsize,
                        va='center',
                        ha='left',
                        textcoords='offset points',
                        xytext=xytext,
                        arrowprops=dict(
                            facecolor='black',
                            color='black',
                            arrowstyle='->',
                            connectionstyle='arc3,rad=0.2'
                        )
                    )
                elif row["model"] in MODEL_TO_ANNOATE_left:
                    xytext = (-65.5,5.5)
                    ax[metric].annotate(
                        row['model'],
                        xy,
                        fontsize=fontsize,
                        va='center',
                        ha='left',
                        textcoords='offset points',
                        xytext=xytext,
                        arrowprops=dict(
                            facecolor='black',
                            color='black',
                            arrowstyle='->',
                            connectionstyle='arc3,rad=0.2'
                        )
                    )

        # annotate correllation
        corr = avg_results[[METRIC, metric]].corr(method="spearman").iloc[0, 1]
        kendall = avg_results[[METRIC, metric]].corr(method="kendall").iloc[0, 1]
        pearson = avg_results[[METRIC, metric]].corr(method="pearson").iloc[0, 1]
        if metric != 'ADMET':
            ax[metric].annotate(
                r"$\varrho_s$" + f": {corr:.2f}\n" + r" $\tau$" + f": {kendall:.2f}",
                (0.55, 0.05),
                xycoords='axes fraction',
                fontsize=8,
                bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.3'),
            )
        else:
            ax[metric].annotate(
                r"$\varrho_s$" + f": {corr:.2f}\n" + r" $\tau$" + f": {kendall:.2f}",
                (0.75, 0.1),
                xycoords='axes fraction',
                fontsize=12,
                bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.5'),
            )

        if metric == 'ADMET':
            ax[metric].set_title(metric, fontweight='bold')
            ax[metric].set_ylabel(METRIC + "  " + r"$(rank\downarrow)$")
            ax[metric].set_xlabel('Downstream tasks  $(rank\downarrow)$')
            ax[metric].set_xlim(5.5,22.7)
            # set xticks as ints
            ax[metric].set_xticks(range(6, 24, 3))
            ax[metric].set_yticks(range(0, 30, 5))
        else:
            ax[metric].set_title(metric.replace('Average', '').replace("(", "\n("), fontsize=10,)
            ax[metric].set_xlabel('')
            ax[metric].set_ylabel('')
            ax[metric].set_xticks([])
            ax[metric].set_yticks([])


    fig.tight_layout()

    plt.savefig(f"{LATEX_FIG_PATH}/molecule/meanrank_task_size_allfig_scaff.pdf", format = "pdf", bbox_inches = 'tight')


In [None]:
# Draw and save the global figure, labelling the first two models to the right
# of their markers and the other three to the left.
plot_glob_mosaic(
    avg_results,
    METRIC=METRIC,
    MODEL_TO_ANNOATE={"MolR_tag", "FRAD_QM9"},
    MODEL_TO_ANNOATE_left={"Not-trained", "ContextPred", "GraphMVP"},
    FIGSIZE=4.5,
    fontsize=9,
)

In [None]:
#same fig by separating on task_type
%matplotlib inline
df_downs = get_ranked_df(df, split_on="task_category", path="results/TDC_ADMET_SCAFF.csv")
FIGSIZE = 2
clear_output()
fig,axes = plt.subplots(1,df_downs.task_category.nunique(), figsize=(FIGSIZE*df_downs.task_category.nunique(),FIGSIZE*1.2), sharey=True, )

for i,task_category in enumerate(["Absorption", "Distribution", "Metabolism", "Excretion", "Tox"]):
    df_tmp = df_downs[df_downs.task_category == task_category].groupby("embedder").mean()
    sns.scatterplot(data=df_tmp, x="global_meanrank_metric", y="meanrank_information", hue="embedder", ax=axes[i], legend=False, style="embedder", s=75)
    sns.regplot(data=df_tmp, x="global_meanrank_metric", y="meanrank_information", ax=axes[i], scatter=False, color="blue", line_kws = {"alpha":0.2})

    # Display the correlation coefficient
    pearson = df_tmp[["global_meanrank_metric","meanrank_information"]].corr("pearson").iloc[0,1]
    corr = df_tmp[["global_meanrank_metric","meanrank_information"]].corr("spearman").iloc[0,1]
    kendall = df_tmp[["global_meanrank_metric","meanrank_information"]].corr("kendall").iloc[0,1]
    axes[i].text(
        0.7,
        0.18,
        r"$\varrho_s$: " +  f"{corr:.2f}\n"+r"$\tau $: "+f"{kendall:.2f}",
        horizontalalignment='center',
        verticalalignment='center',
        transform=axes[i].transAxes
    )

    axes[i].set_title(f"{task_category}")
    axes[i].set_ylabel("")
    axes[i].set_xlabel("")
    axes[i].set_xticks([])
    axes[i].set_yticks([])

fig.supxlabel(f"Downstream tasks $(rank)$")
fig.supylabel(f"{METRIC}  $(rank)$",)

fig.tight_layout()
plt.savefig(f"{LATEX_FIG_PATH}/molecule/meanrank_task_category_scaff.pdf", format = "pdf", bbox_inches = 'tight')

In [None]:
# The per-task-category figure (meanrank_task_category_scaff.pdf) is produced by
# the previous cell. This cell repeated that code verbatim (same split, same
# panels, same output path), so it only redrew and overwrote the same file;
# the duplicate code was dropped. Re-run the previous cell to regenerate it.

# Comprehensive Results

In [None]:
%matplotlib inline
df_downs = get_ranked_df(df, path="results/TDC_ADMET_SCAFF.csv")

n_rows = 5
n_cols = int(np.ceil(df_downs.dataset.nunique() / n_rows))
import pandas as pd



FIGSIZE = 1.8
fig, axes = plt.subplots(n_rows,n_cols, figsize=(FIGSIZE*n_cols,FIGSIZE*n_rows*1.2))
axes = axes.flatten()
df_corr = pd.DataFrame(columns=["dataset", "kendall", "pearson", "spearman"])

for i,dataset in enumerate(df_downs.sort_values("task_size").dataset.unique()):
    df_tmp = df_downs[(df_downs.dataset == dataset)]
    df_tmp.embedder = df_tmp.embedder.apply(prerpocess_emb_name)
    # compute ranking for roc
    df_tmp = df_tmp.groupby("embedder").median()
    sns.scatterplot(data=df_tmp, x="metric", y="meanrank_information", hue="embedder", ax=axes[i], legend=False)
    # add linear regression
    sns.regplot(data=df_tmp, x="metric", y="meanrank_information", ax=axes[i], scatter=False, color="blue", line_kws = {"alpha":0.2})
    axes[i].set_title(
            process_dataset_name(dataset)
    )
    # Remove the x and y axis labels
    axes[i].set_ylabel("")
    axes[i].set_xlabel("")
    # Remove the xticks
    #axes[i].set_xticks([])
    #axes[i].set_yticks([])
    pearson = df_tmp[["metric","meanrank_information"]].corr("pearson").iloc[0,1]
    corr = df_tmp[["metric","meanrank_information"]].corr("spearman").iloc[0,1]
    corr_kendall = df_tmp[["metric","meanrank_information"]].corr("kendall").iloc[0,1]

    df_corr = df_corr.append({"dataset":dataset, "kendall":-corr_kendall, "pearson":-pearson, "spearman":-corr}, ignore_index=True)
# Remove the empty subplots
for i in range(len(df_downs.dataset.unique()), len(axes)):
    fig.delaxes(axes[i])

fig.tight_layout()



df_downs = df_downs.join(df_corr.set_index("dataset"), on="dataset")

#plt.savefig(f"{LATEX_FIG_PATH}/molecule/all_TDC_results_scaff.pdf", format = "pdf", bbox_inches = 'tight')