# KNIFE mutual-information (MI) analysis

In [None]:
import os

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm as tqdm
import numpy as np
import json
import sklearn
from sklearn.decomposition import PCA

from utils import MolecularFeatureExtractor
from models.model_paths import get_model_path

from main import GROUPED_MODELS

from utils_notebook import plot_embeddings, MODELS, MODELS_PATH, get_loss_df,get_MI_df

In [None]:
# plot_embeddings("ClinTox")

In [None]:
# Dataset whose model embeddings are analysed, and the result run directories
# passed to get_loss_df / get_MI_df below.
DATASET = "ZINC"
results_dir_list = [
    "run_2-1layer",
    "run_2",
    "run_2_3",
    "run_2_4",
]

In [None]:
# Hyper-parameter column used to split the results into separate plots and
# comparisons below (e.g. df[df[COLUMS_SPLIT] == 1]). The spelling is kept
# because every later cell refers to this name.
COLUMS_SPLIT = "ff_layers"

In [None]:
# Load the training-loss tables for every run directory: one frame for the
# marginal entropy and one for the conditional entropy. Columns used below
# include X, Y, epoch, marg_ent / cond_ent and the estimator hyper-parameters;
# the exact schema is defined by utils_notebook.get_loss_df.
full_df_loss_marg, full_df_loss_cond = get_loss_df(DATASET, results_dir_list)

In [None]:
# Marginal-entropy training curves: one panel per source model X, one line per
# COLUMS_SPLIT value.
marg_models = full_df_loss_marg.X.unique()
n_cols = 5
# Exactly enough rows for all models (no spare empty row).
n_rows = max(1, int(np.ceil(len(marg_models) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, 4 * n_rows), squeeze=False)
axes = axes.flatten()
for i, model in enumerate(tqdm(marg_models)):
    df_tmp = full_df_loss_marg[full_df_loss_marg.X == model]
    sns.lineplot(data=df_tmp, x="epoch", y="marg_ent", hue=COLUMS_SPLIT, ax=axes[i], estimator=None, errorbar=None, n_boot=0, legend=False)
    axes[i].set_title(model)
    axes[i].set_xlabel("")
    axes[i].set_ylabel("Marginal entropy")
# Hide grid cells that received no model.
for ax in axes[len(marg_models):]:
    ax.set_visible(False)

In [None]:
full_df_loss_cond["epoch"] = full_df_loss_cond["epoch"].astype(int)
# Average all rows that share the same (X, Y, hyper-parameters, epoch).
# numeric_only=True keeps this working on pandas>=2, where mean() raises on any
# remaining non-numeric column instead of silently dropping it.
full_df_loss_cond = (
    full_df_loss_cond
    .groupby(["X", "Y", "cond_modes", "marg_modes", "ff_hidden_dim", "ff_layers", "epoch"])
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:

# Conditional-entropy H(Y|X) training curves: one figure per ff_layers value,
# one panel per target model Y, one line per source model X.
cond_targets = full_df_loss_cond.Y.unique()
n_cols = 4
n_rows = max(1, int(np.ceil(len(cond_targets) / n_cols)))
for ff_layers in full_df_loss_cond.ff_layers.unique():
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, 4 * n_rows), squeeze=False)
    axes = axes.flatten()
    for i, model in enumerate(tqdm(cond_targets)):
        df_tmp = full_df_loss_cond[
            (full_df_loss_cond.Y == model) & (full_df_loss_cond.ff_layers == ff_layers)
        ]
        sns.lineplot(data=df_tmp, x="epoch", y="cond_ent", hue="X", ax=axes[i], errorbar=None, n_boot=0, legend=False)
        axes[i].set_title(model)
        axes[i].set_xlabel("")
        axes[i].set_ylabel("H(Y|X)")
    # Hide grid cells that received no target model.
    for ax in axes[len(cond_targets):]:
        ax.set_visible(False)
    # Label each figure so the per-ff_layers figures can be told apart.
    fig.suptitle(f"ff_layers: {ff_layers}")
    plt.show()

In [None]:
# Display the last frame filtered by the conditional-entropy loop
# (the last target Y for the last ff_layers value).
df_tmp

## MI between descriptors and embeddings

In [None]:
import os
import numpy as np

# Mutual-information table (columns used below include X, Y, X_dim, Y_dim,
# I(X->Y) and the estimator hyper-parameters).
df = get_MI_df(DATASET, results_dir_list)

#df =df[df.Y.isin(DESCRIPTORS)]

# Average embedding dimension reported for each source model X.
mean_x_dim = df.groupby("X")["X_dim"].mean()
mean_x_dim


## Clustermap

In [None]:
# MI per target dimension.
df["I(X->Y)/dim"] = df["I(X->Y)"] / df["Y_dim"]

# Columns identifying one estimator configuration for a given target Y.
norm_keys = ["Y", "cond_modes", "marg_modes", "ff_hidden_dim", "ff_layers"]

# Reference value per configuration: mean over all source models X != Y.
# numeric_only=True is needed on pandas>=2 because X is a string column.
df_grouped = (
    df[df.X != df.Y]
    .groupby(norm_keys)
    .mean(numeric_only=True)
    .reset_index()
)

# Vectorised lookup of each row's reference value. A left merge keeps the row
# order of df, and the grouped keys are unique, so rows line up 1:1. This
# replaces a per-row boolean scan of df_grouped (O(n^2)).
mi_reference = df[norm_keys].merge(
    df_grouped[norm_keys + ["I(X->Y)", "I(X->Y)/dim"]],
    on=norm_keys,
    how="left",
)
df["I(X->Y)_norm"] = df["I(X->Y)"].to_numpy() / mi_reference["I(X->Y)"].to_numpy()
df["I(X->Y)/dim_norm"] = df["I(X->Y)/dim"].to_numpy() / mi_reference["I(X->Y)/dim"].to_numpy()

# Configurations without a reference row yield NaN; treat them (and any other
# missing values) as 0.
df = df.fillna(0)
df

In [None]:
from scipy.cluster.hierarchy import linkage

def plot_cmap(df, keys, cmap = "copper", vmin = None, vmax = None, center = None, values=False, same_linkage=True, title=""):
    """Draw one clustered heatmap per column in ``keys`` and show them side by side.

    For each key, the mean of that column is pivoted into an X (rows) by Y
    (columns) matrix and drawn with ``sns.clustermap``. Each clustermap is
    saved to ``fig/cluster_{i}.png``; the saved images are then tiled into a
    single figure titled ``title``.

    Args:
        df: long-format frame with columns ``X``, ``Y`` and every column in ``keys``.
        keys: columns to plot, one heatmap each.
        cmap: colormap passed to ``sns.clustermap``.
        vmin, vmax: per-key lists of colour limits (``None`` = automatic).
        center: value at which a diverging colormap is centred.
        values: if True, annotate the cells with their values.
        same_linkage: if True, compute one Ward linkage on the pivot rows and
            reuse it for the columns so both axes share one ordering. This
            assumes a square pivot with matching row/column labels.
        title: super-title of the combined figure.
    """
    import matplotlib.image as mpimg

    if vmax is None:
        vmax = [None] * len(keys)
    if vmin is None:
        vmin = [None] * len(keys)

    # savefig does not create missing directories.
    os.makedirs("fig", exist_ok=True)

    for i, key in enumerate(keys):
        df_pivot = df.pivot_table(index="X", columns="Y", values=key, aggfunc="mean")
        link = linkage(df_pivot, method="ward") if same_linkage else None
        cluster = sns.clustermap(
            df_pivot, row_linkage=link, col_linkage=link,
            cmap=cmap, figsize=(8, 8), vmin=vmin[i], vmax=vmax[i], center=center, annot=values
        )
        cluster.savefig("fig/cluster_{}.png".format(i))
        # Close (not just clear) the clustermap figure. It is only needed as an
        # image file; a cleared-but-open figure stays in memory and may still be
        # rendered as a blank image by the inline backend.
        plt.close(cluster.fig)

    # squeeze=False keeps a 2-D axes array even when there is a single key.
    fig, axes = plt.subplots(1, len(keys), figsize=(8 * len(keys), 8), squeeze=False)
    for i, (ax, key) in enumerate(zip(axes[0], keys)):
        ax.imshow(mpimg.imread("fig/cluster_{}.png".format(i)))
        ax.axis("off")
        ax.set_title(key)

    fig.suptitle(title)
    plt.show()

In [None]:
%matplotlib inline

plot_cmap(
    df[df[COLUMS_SPLIT]==1],
    ["I(X->Y)", "I(X->Y)/dim"],
    cmap="viridis",
    vmin=[None,None],
    vmax=[None,None],
    same_linkage=True,
    title="Clustermap of the mutual information between models - {} : {}".format(COLUMS_SPLIT, 1)
)

plot_cmap(
    df[df[COLUMS_SPLIT]==2],
    ["I(X->Y)", "I(X->Y)/dim"],
    cmap="viridis",
    vmin=[None,None],
    vmax=[None,None],
    same_linkage=True,
     title="Clustermap of the mutual information between models - {} : {}".format(COLUMS_SPLIT, 2)
)

plot_cmap(
    df[df[COLUMS_SPLIT]==3],
    ["I(X->Y)", "I(X->Y)/dim"],
    cmap="viridis",
    vmin=[None,None],
    vmax=[None,None],
    same_linkage=True,
     title="Clustermap of the mutual information between models - {} : {}".format(COLUMS_SPLIT, 3)
)

plot_cmap(
    df[df[COLUMS_SPLIT]==4],
    ["I(X->Y)", "I(X->Y)/dim"],
    cmap="viridis",
    vmin=[None,None],
    vmax=[None,None],
    same_linkage=True,
        title="Clustermap of the mutual information between models - {} : {}".format(COLUMS_SPLIT, 4)
)

In [None]:
def _relative_split_diff(df, split_a, split_b, cols=("I(X->Y)", "I(X->Y)/dim", "I(X->Y)/dim_norm")):
    """Return (value at split_a - value at split_b) / mean over all rows, per (X, Y).

    Rows are aligned on the (X, Y) index. This presumably expects each (X, Y)
    pair to occur once per COLUMS_SPLIT value — TODO confirm when several
    estimator configurations share the same split value.
    """
    cols = list(cols)
    at_a = df[df[COLUMS_SPLIT] == split_a].set_index(["X", "Y"])[cols]
    at_b = df[df[COLUMS_SPLIT] == split_b].set_index(["X", "Y"])[cols]
    mean_all = df.groupby(["X", "Y"])[cols].mean()
    return ((at_a - at_b) / mean_all).reset_index()


df_diff_14 = _relative_split_diff(df, 4, 1)

df_diff_12 = _relative_split_diff(df, 2, 1)

In [None]:
# Relative MI change between COLUMS_SPLIT settings (2 vs 1, then 4 vs 1),
# drawn with a diverging colormap centred at 0.
for df_diff, (split_a, split_b) in ((df_diff_12, (2, 1)), (df_diff_14, (4, 1))):
    plot_cmap(
        df_diff,
        ["I(X->Y)/dim", "I(X->Y)/dim_norm"],
        cmap="vlag",
        vmin=[None, None],
        vmax=[None, None],
        same_linkage=True,
        center=0,
        title="Relative MI change - {} : {} vs {}".format(COLUMS_SPLIT, split_a, split_b),
    )

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patheffects as patheffects

# Use seaborn's "whitegrid" theme for the figures drawn after this cell.
sns.set_style("whitegrid")


In [None]:
from netgraph import Graph, InteractiveGraph
from networkx.algorithms.community import girvan_newman, modularity_max, louvain_communities

def _scale_edge_values(values, clip_quantile, out_min, out_max, power=1.0):
    """Min-max rescale an ``{edge: value}`` mapping into ``[out_min, out_max]``.

    The lower reference point is the ``clip_quantile`` quantile of the values
    (0 = the minimum); values below it are clipped to ``out_min``. ``power`` is
    applied to the normalised value in [0, 1] before it is mapped to the output
    range. If all values are equal, every edge gets ``out_max``.
    """
    if not values:
        return {}
    arr = np.asarray(list(values.values()), dtype=float)
    low = np.nanquantile(arr, clip_quantile)
    span = np.nanmax(arr) - low
    scaled = {}
    for edge, value in values.items():
        frac = (value - low) / span if span > 0 else 1.0
        frac = min(max(frac, 0.0), 1.0) ** power
        scaled[edge] = frac * (out_max - out_min) + out_min
    return scaled


def plot_com(df_in, weight_col = "I(X->Y)/dim", cmap="crest", edge_cmap="flare", min_alpha = 0.0,
    max_alpha = 1.0, min_edge_width = 0.0, max_edge_width = 1, edge_pow=2, node_layout="community", edge_layout="bundled", com_resolution=1.1, figsize=10, clip_min_values_width = 0, clip_min_values_alpha = 0, com_pad_by = 0.001, fontsize = 15, undirected=False, sparsity=1,):
    """Draw the directed model-to-model MI graph with Louvain communities.

    Two directed graphs are built from ``df_in`` after self pairs (X == Y) are removed:

    * ``G`` uses the mean of ``weight_col`` per (X, Y). It drives the Louvain
      community detection and the node colours (median outgoing weight for the
      node face, median incoming weight for the node border).
    * ``G_un`` uses the mean of ``f"{weight_col}_norm"``. Its non-zero edges
      (after optional sparsification) are the ones drawn; edge colour, alpha
      and width are derived from that weight.

    Args:
        df_in: long-format frame with columns ``X``, ``Y``, ``weight_col`` and
            ``f"{weight_col}_norm"``.
        weight_col: column used for ``G``.
        cmap, edge_cmap: seaborn palette names for nodes and edges.
        min_alpha, max_alpha: output range of the edge alpha.
        min_edge_width, max_edge_width: output range of the edge width.
        edge_pow: exponent applied to the normalised [0, 1] weight before it is
            mapped to a width.
        node_layout, edge_layout: netgraph layout names. ``"community"`` lays
            nodes out by the Louvain partition.
        com_resolution: Louvain resolution (>1 favours smaller communities).
        figsize: side length of the square figure, in inches.
        clip_min_values_width, clip_min_values_alpha: quantile of the edge
            weights mapped to the minimum width / alpha. Lower weights are
            clipped there; 0 gives plain min-max scaling.
        com_pad_by: passed to netgraph as ``pad_by`` for the community layout.
        fontsize: node label font size.
        undirected: if True, edges are drawn without arrow heads.
        sparsity: if < 1, for each source model X only edges whose normalised
            weight exceeds the ``sparsity`` percentile of X's values are drawn.
    """
    df = df_in[df_in.X != df_in.Y].copy()
    weight_col_norm = f"{weight_col}_norm"

    table = df.pivot_table(index="X", columns="Y", values=weight_col, aggfunc="mean")
    G = nx.from_pandas_adjacency(table, create_using=nx.DiGraph)
    G.remove_edges_from(list(nx.selfloop_edges(G)))

    if sparsity < 1:
        # Zero out, per source model X, the edges at or below the percentile threshold.
        threshold = df.groupby("X")[weight_col_norm].transform(
            lambda x: np.percentile(x.unique(), sparsity * 100)
        )
        df[weight_col_norm] *= df[weight_col_norm] > threshold

    table_un = df.pivot_table(index="X", columns="Y", values=weight_col_norm, aggfunc="mean")
    G_un = nx.from_pandas_adjacency(table_un, create_using=nx.DiGraph)
    G_un.remove_edges_from(list(nx.selfloop_edges(G_un)))
    # Zero-weight edges (e.g. removed by the sparsity filter) are not drawn.
    G_un.remove_edges_from([(u, v) for u, v, d in G_un.edges(data=True) if d["weight"] == 0])

    communities = list(louvain_communities(G, resolution=com_resolution))
    cmap = sns.color_palette(cmap, as_cmap=True)
    edge_cmap = sns.color_palette(edge_cmap, as_cmap=True)

    # Median outgoing / incoming weight of each node in G.
    avg_weight = {n: np.median([d["weight"] for _, _, d in G.out_edges(n, data=True)]) for n in G.nodes()}
    avg_income = {n: np.median([d["weight"] for _, _, d in G.in_edges(n, data=True)]) for n in G.nodes()}
    node_to_community = {node: i for i, community in enumerate(communities) for node in community}

    node_color = {node: cmap(avg_weight[node]) for node in G.nodes()}
    node_edge_color = {node: cmap(avg_income[node]) for node in G.nodes()}
    node_labels = {node: node for node in G.nodes()}

    edge_weights = {edge: G_un.edges[edge]["weight"] for edge in G_un.edges()}
    edge_color = {edge: edge_cmap(w) for edge, w in edge_weights.items()}
    edge_alpha = _scale_edge_values(edge_weights, clip_min_values_alpha, min_alpha, max_alpha)
    # edge_pow shapes the normalised weight in [0, 1], not the (max - min) range.
    edge_width = _scale_edge_values(
        edge_weights, clip_min_values_width, min_edge_width, max_edge_width, power=edge_pow
    )

    fig, ax = plt.subplots(figsize=(figsize, figsize))
    if node_layout == "community":
        node_layout_kwargs = dict(node_to_community=node_to_community, pad_by=com_pad_by)
    else:
        node_layout_kwargs = {}

    graph = Graph(G_un, node_layout_kwargs=node_layout_kwargs, node_layout=node_layout, node_color=node_color, node_labels=node_labels, edge_color=edge_color, ax=ax, node_label_fontdict={'fontsize': fontsize, 'fontweight': 'bold'}, node_edge_color=node_edge_color, edge_layout=edge_layout, edge_alpha=edge_alpha, arrows=not undirected, prettify=True, edge_width=edge_width)

    # Add a white contour to every text on the axes so labels stay readable over edges.
    for text in ax.texts:
        text.set_path_effects([patheffects.Stroke(linewidth=4, foreground='white'), patheffects.Normal()])

In [None]:
# Models included in the community-graph views.
models_to_cons = [
    "ContextPred",
    "GPT-GNN",
    "GraphMVP",
    "GROVER",
    "AttributeMask",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "MolBert",
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "DenoisingPretrainingPQCMv4",
    "FRAD_QM9",
    "MolR_tag",
    "MoleOOD_OGB_GCN",
    "ThreeDInfomax",
]

# Keep only pairs whose source X and target Y are both listed, for ff_layers == 2.
both_listed = df.X.isin(models_to_cons) & df.Y.isin(models_to_cons)
plot_com(
    df[both_listed & (df.ff_layers == 2)],
    figsize=10,
    edge_layout="straight",
    com_resolution=1.2,
    min_edge_width=0.5,
    max_edge_width=1,
    min_alpha=0.2,
    max_alpha=1,
    clip_min_values_alpha=0.3,
    cmap="vlag",
    edge_cmap="flare",
    fontsize=5,
    com_pad_by=0.001,
    sparsity=0.9,
)

#plt.savefig("fig/MI_graph.pdf", format = "pdf", bbox_inches = 'tight')

In [None]:
# Same community graph restricted to ff_layers == 3.
graph_kwargs = dict(
    figsize=10,
    edge_layout="straight",
    com_resolution=1.2,
    min_edge_width=0.5,
    max_edge_width=1,
    min_alpha=0.2,
    max_alpha=1,
    clip_min_values_alpha=0.3,
    cmap="vlag",
    edge_cmap="flare",
    fontsize=5,
    com_pad_by=0.001,
    sparsity=0.9,
)
selection = (df.ff_layers == 3) & df.X.isin(models_to_cons) & df.Y.isin(models_to_cons)
plot_com(df[selection], **graph_kwargs)

#plt.savefig("fig/MI_graph.pdf", format = "pdf", bbox_inches = 'tight')

In [None]:
# Same community graph restricted to ff_layers == 4.
graph_kwargs = dict(
    figsize=10,
    edge_layout="straight",
    com_resolution=1.2,
    min_edge_width=0.5,
    max_edge_width=1,
    min_alpha=0.2,
    max_alpha=1,
    clip_min_values_alpha=0.3,
    cmap="vlag",
    edge_cmap="flare",
    fontsize=5,
    com_pad_by=0.001,
    sparsity=0.9,
)
selection = (df.ff_layers == 4) & df.X.isin(models_to_cons) & df.Y.isin(models_to_cons)
plot_com(df[selection], **graph_kwargs)

#plt.savefig("fig/MI_graph.pdf", format = "pdf", bbox_inches = 'tight')

In [None]:
# Community graph over all models and all ff_layers values.
graph_kwargs = dict(
    figsize=10,
    edge_layout="straight",
    com_resolution=0.99,
    min_edge_width=1,
    max_edge_width=2,
    min_alpha=0.2,
    max_alpha=1,
    clip_min_values_alpha=0.5,
    cmap="vlag",
    edge_cmap="flare",
    fontsize=5,
    com_pad_by=0.001,
    sparsity=0.9,
)
plot_com(df, **graph_kwargs)

In [None]:
# Display the current MI table.
df

In [None]:
# Median of every numeric MI column per (source model X, COLUMS_SPLIT value).
# "information" is the median per-dimension MI of X over its targets Y.
# numeric_only=True is required on pandas>=2 because Y is a string column.
df_avg = df.groupby(["X", COLUMS_SPLIT]).median(numeric_only=True)
df_avg["information"] = df_avg["I(X->Y)/dim"]
fig, ax = plt.subplots(figsize=(6, 6))
sns.barplot(df_avg.sort_values("information"), x="information", hue=COLUMS_SPLIT, legend=False, ax=ax, y="X")

# Labels say "Median" to match the aggregation above.
ax.set_title("Median predictive mutual information of a model to predict the other models")
ax.set_ylabel("")
ax.set_xlabel("Median predictive mutual information")
os.makedirs("fig", exist_ok=True)
plt.savefig("fig/mean_information.pdf", format="pdf", bbox_inches="tight")

In [None]:
# Downstream-task results (columns used below include dataset, embedder, metric).
df_downs = pd.read_csv("results/tmp.csv")
# df_avg keeps its (X, COLUMS_SPLIT) index: "X" is an index level, not a column,
# and the next cell re-keys df_avg by that "X" level, so nothing is renamed here.


In [None]:
# Attach each embedder's "information" score for every COLUMS_SPLIT value.
# df_avg has one row per (X, split value), so downstream rows are repeated
# once per split value.
info_by_model = df_avg.reset_index().set_index("X")[["information", COLUMS_SPLIT]]
df_downs = df_downs.join(info_by_model, on="embedder")
df_downs

In [None]:
# Inspect the DILI rows of ChemBertMLM-10M.
df_downs.loc[lambda d: (d.dataset == "DILI") & (d.embedder == "ChemBertMLM-10M")]

In [None]:
from autorank import autorank
#rank each model on each task


In [None]:
# Rank embedders on the downstream metric across datasets (rows = datasets,
# columns = embedders) with autorank in forced non-parametric mode.
df_to_rank = df_downs.pivot_table(index="dataset", columns=["embedder"], values="metric", aggfunc="mean")
rank_result = autorank(df_to_rank, alpha=0.05, verbose=False, force_mode="nonparametric")
res = rank_result.rankdf.rename(columns={"meanrank": "meanrank_metric"})
res

In [None]:
# Add each embedder's mean downstream rank (the "meanrank_metric" column of the
# autorank result `res`) to df_downs, matched on the embedder name.
df_downs = df_downs.join(res[["meanrank_metric"]], on=["embedder"])

In [None]:
# Mean rank of each embedder's "information" score, computed separately for
# every COLUMS_SPLIT value and written back into the matching rows.
df_downs["meanrank_information"] = np.nan
for split_value in df_downs[COLUMS_SPLIT].unique():
    in_split = df_downs[COLUMS_SPLIT] == split_value
    df_to_rank = df_downs[in_split].pivot_table(index="dataset", columns="embedder", values="information", aggfunc="mean")
    res = autorank(df_to_rank, alpha=0.05, verbose=False, force_mode="nonparametric").rankdf
    res = res.rename(columns={"meanrank": "meanrank_information"})
    df_downs.loc[in_split, "meanrank_information"] = df_downs.loc[in_split, "embedder"].map(res["meanrank_information"])


In [None]:
# Display df_downs, including the meanrank_metric / meanrank_information columns.
df_downs

In [None]:
# Downstream metric vs. "information", one panel per dataset, with a linear trend.
datasets = df_downs.dataset.unique()
n_rows = 5
# Round up so every dataset gets a panel (floor division dropped the remainder
# and made axes[i] raise IndexError).
n_cols = max(1, int(np.ceil(len(datasets) / n_rows)))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), sharey=True, squeeze=False)
axes = axes.flatten()

for i, dataset in enumerate(datasets):
    df_tmp = df_downs[df_downs.dataset == dataset]
    sns.scatterplot(data=df_tmp, x="metric", y="information", hue="embedder", ax=axes[i], legend=False)
    # Linear-regression trend line.
    sns.regplot(data=df_tmp, x="metric", y="information", ax=axes[i], scatter=False, color="blue", line_kws={"alpha": 0.2})
    axes[i].set_title(dataset)
# Hide grid cells that received no dataset.
for ax in axes[len(datasets):]:
    ax.set_visible(False)

fig.tight_layout()
os.makedirs("fig", exist_ok=True)
plt.savefig("fig/roc_vs_information.pdf", format="pdf", bbox_inches="tight")

In [None]:
%matplotlib inline

fig, axes = plt.subplots(1,df_downs[COLUMS_SPLIT].nunique(), figsize=(5*df_downs[COLUMS_SPLIT].nunique(),5))


for i,ff_layers in enumerate(df_downs[COLUMS_SPLIT].unique()):
    df_tmp = df_downs[df_downs[COLUMS_SPLIT] == ff_layers]
    df_tmp = df_tmp.groupby("embedder").mean()
    sns.scatterplot(data=df_tmp, x="meanrank_metric", y="meanrank_information", hue="embedder", ax=axes[i], legend=False)
    sns.regplot(data=df_tmp, x="meanrank_metric", y="meanrank_information", ax=axes[i], scatter=False, color="blue", line_kws = {"alpha":0.2})

    # Display the correlation coefficient
    corr = np.corrcoef(df_tmp["meanrank_metric"], df_tmp["meanrank_information"])[0,1]
    axes[i].text(0.5, 0.5, f"Correlation: {corr:.4f}", horizontalalignment='center', verticalalignment='center', transform=axes[i].transAxes)

    axes[i].set_title(f"{COLUMS_SPLIT}: {ff_layers}")
    axes[i].set_ylabel("Mean rank of models' predictivity")
    axes[i].set_xlabel("Mean rank on downstream tasks")


plt.savefig("fig/meanrank.pdf", format = "pdf", bbox_inches = 'tight')

In [None]:
# Downstream metric of every embedder, one panel per dataset; within a panel,
# embedders are ordered by their "information" score.
datasets = df_downs.dataset.unique()
n_rows = 3
# Round up so every dataset gets a panel.
n_cols = max(1, int(np.ceil(len(datasets) / n_rows)))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), sharey=True, squeeze=False)
axes = axes.flatten()

for i, dataset in enumerate(datasets):
    # Filter first, then sort: the mask then matches the frame's index directly
    # instead of being re-aligned against a re-ordered frame.
    df_tmp = df_downs[df_downs.dataset == dataset].sort_values("information")
    sns.scatterplot(data=df_tmp, x="metric", y="embedder", hue="embedder", ax=axes[i], legend=False)
    axes[i].set_title(dataset)
    # Rotate the existing tick labels without freezing them to their current text.
    axes[i].tick_params(axis="x", labelrotation=90)
# Hide grid cells that received no dataset.
for ax in axes[len(datasets):]:
    ax.set_visible(False)

fig.tight_layout()

In [None]:
%matplotlib inline
fig = plt.figure(figsize=(6,6))
sns.scatterplot(data=df_downs.groupby("embedder").mean().sort_values("information"), y="embedder", x="meanrank_metric", hue="embedder", legend=False)

plt.savefig("fig/meanrank_roc.pdf", format = "pdf", bbox_inches = 'tight')