# Knife MI analysis

In [None]:
import os

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm as tqdm
import numpy as np
import json
import sklearn
from sklearn.decomposition import PCA

from utils import MolecularFeatureExtractor
from models.model_paths import get_model_path

In [None]:
# Analysis configuration. LENGTH and MDS_DIM are forwarded to
# MolecularFeatureExtractor (as `length` / `mds_dim`) and also select the
# results sub-directories read further down.
DATASET = "ZINC"
LENGTH = 1024
MDS_DIM = 100

# Descriptor / fingerprint names.
DESCRIPTORS = [
    "ecfp", "estate", "fcfp", "erg", "rdkit", "topological", "avalon", "maccs",
    "secfp", "scaffoldkeys", "cats", "gobbi", "pmapper",
    "cats/3D", "gobbi/3D", "pmapper/3D",
    "ScatteringWavelet",
]

# Pretrained embedders to compare (order is preserved in every plot).
MODELS = [
    "ContextPred", "GPT-GNN", "GraphMVP", "GROVER",
    # "EdgePred",  # excluded: especially bad, makes the visualization hard
    "AttributeMask", "GraphLog", "GraphCL", "InfoGraph",
    "Not-trained",
    "MolBert",
    "ChemBertMLM-5M", "ChemBertMLM-10M", "ChemBertMLM-77M",
    "ChemBertMTR-5M", "ChemBertMTR-10M", "ChemBertMTR-77M",
    "ChemGPT-1.2B", "ChemGPT-19M", "ChemGPT-4.7M",
    "DenoisingPretrainingPQCMv4",
    "FRAD_QM9",
    "MolR_gat", "MolR_gcn", "MolR_tag",
    "MoleOOD_OGB_GIN", "MoleOOD_OGB_GCN", "MoleOOD_OGB_SAGE",
    "ThreeDInfomax",
]

# Mapping model name -> checkpoint path (models without one are looked up with .get()).
MODELS_PATH = get_model_path(models=MODELS)


with open(f"data/{DATASET}/smiles.json", "r") as f:
    smiles = json.load(f)

# 3D PCA scatter of every model's standardized embeddings. Only run for
# datasets with at most 50,000 molecules (presumably to keep feature
# extraction tractable).
if len(smiles) <= 50000:
    import datamol as dm

    mols = dm.read_sdf(f"data/{DATASET}/preprocessed.sdf")

    feature_extractor = MolecularFeatureExtractor(
        dataset=DATASET, length=LENGTH, mds_dim=MDS_DIM, device="cuda"
    )

    # 3-row grid; round the column count UP so every model gets an axis
    # (len(MODELS) // 3 left the last model(s) without one -> IndexError).
    n_rows = 3
    n_cols = int(np.ceil(len(MODELS) / n_rows))
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(n_cols * 5, n_rows * 5),
        subplot_kw={"projection": "3d"},
    )
    axes = axes.flatten()

    for ax, model in zip(axes, MODELS):
        embeddings = feature_extractor.get_features(
            smiles, mols=mols, name=model, feature_type="model",
            path=MODELS_PATH.get(model, None),
        )
        # Standardize each embedding dimension; epsilon guards constant dims.
        embeddings = (embeddings - embeddings.mean(axis=0)) / (embeddings.std(axis=0) + 1e-8)
        embeddings_pca = PCA(n_components=3).fit_transform(embeddings.cpu())
        df = pd.DataFrame(embeddings_pca, columns=["PC1", "PC2", "PC3"])
        df["smiles"] = smiles
        ax.scatter3D(df["PC1"], df["PC2"], df["PC3"], c=df["PC1"], cmap="viridis", alpha=0.1)
        ax.set_title(model)

    # Hide grid cells left over when len(MODELS) is not a multiple of n_rows.
    for ax in axes[len(MODELS):]:
        ax.set_visible(False)


In [None]:
# Collect the per-pair training-loss logs for DATASET from <RESULTS_PATH>/losses:
# "<DATASET>_..._marg.csv" files go to the marginal table, every other
# "<DATASET>_....csv" file to the conditional table.
RESULTS_PATH = f"results/{DATASET}/{LENGTH}/{MDS_DIM}"
dir_path = os.path.join(RESULTS_PATH, "losses")

full_df_loss_cond, full_df_loss_marg = [], []
for file in tqdm(os.listdir(dir_path)):
    stem, ext = os.path.splitext(file)
    if ext != ".csv" or stem.split("_")[0] != DATASET:
        continue
    bucket = full_df_loss_marg if stem.split("_")[-1] == "marg" else full_df_loss_cond
    bucket.append(pd.read_csv(os.path.join(dir_path, file)))

full_df_loss_cond = pd.concat(full_df_loss_cond)
full_df_loss_marg = pd.concat(full_df_loss_marg)



In [None]:
# Inspect the concatenated conditional-entropy loss logs.
full_df_loss_cond

In [None]:
# Marginal-entropy training curves: one panel per model in full_df_loss_marg.
marg_models = full_df_loss_marg.X.unique()
n_cols = 5
# Size the grid from the models actually plotted; the old code used the
# conditional table's Y count, which can allocate too few axes.
n_rows = max(1, int(np.ceil(len(marg_models) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, 4 * n_rows), squeeze=False)
axes = axes.flatten()
for i, model in enumerate(tqdm(marg_models)):
    df_tmp = full_df_loss_marg[full_df_loss_marg.X == model]
    sns.lineplot(data=df_tmp, x="epoch", y="marg_ent", hue="X", ax=axes[i], estimator=None, errorbar=None, n_boot=0, legend=False)
    axes[i].set_title(model)
    axes[i].set_xlabel("")
    axes[i].set_ylabel("Marginal entropy")

# Hide the unused cells of the last row.
for ax in axes[len(marg_models):]:
    ax.set_visible(False)

In [None]:
# Remove EdgePred from both loss tables (it is also commented out of MODELS).
full_df_loss_cond = full_df_loss_cond[
    (full_df_loss_cond.Y != "EdgePred") & (full_df_loss_cond.X != "EdgePred")
]

# Bug fix: this used to assign to a new name (full_df_loss_margin), so
# EdgePred was never actually removed from full_df_loss_marg.
full_df_loss_marg = full_df_loss_marg[full_df_loss_marg.X != "EdgePred"]

In [None]:
# Conditional-entropy training curves: one row per model Y; left column is
# the "X->Y" run (H(Y|X)), right column the "Y->X" run (H(X|Y)).
targets = full_df_loss_cond.Y.unique()
n_rows = len(targets)
n_cols = 2
# squeeze=False keeps `axes` 2-D even when there is a single target model.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, 4 * n_rows), squeeze=False)

panels = [("X->Y", "X->Y H(Y|X)"), ("Y->X", "Y->X H(X|Y)")]
for i, model in enumerate(tqdm(targets)):
    for j, (direction, ylabel) in enumerate(panels):
        df_tmp = full_df_loss_cond[
            (full_df_loss_cond.Y == model) & (full_df_loss_cond.direction == direction)
        ]
        sns.lineplot(data=df_tmp, x="epoch", y="cond_ent", hue="X", ax=axes[i, j], estimator=None, errorbar=None, n_boot=0, legend=False)
        axes[i, j].set_title(model)
        axes[i, j].set_xlabel("")
        axes[i, j].set_ylabel(ylabel)


## MI between descriptors and embeddings

In [None]:
import os
import numpy as np

# Gather the pairwise MI tables named "<DATASET>_..._<LENGTH>.csv" in RESULTS_PATH.
all_df = [
    pd.read_csv(os.path.join(RESULTS_PATH, file))
    for file in os.listdir(RESULTS_PATH)
    if file.endswith(".csv")
    and file[:-4].split("_")[0] == DATASET
    and file[:-4].split("_")[-1] == str(LENGTH)
]
df = pd.concat(all_df)
# Drop the LENGTH substring from X names (presumably a size suffix on descriptor names).
df["X"] = df["X"].map(lambda name: name.replace(str(LENGTH), ""))

# Optional restriction to descriptors only: df = df[df.Y.isin(DESCRIPTORS)]
df

In [None]:
# For every (X, Y) row, attach the values of the row with the roles swapped:
# the swapped row's I(Y), I(Y|X) and I(X->Y) become this row's I(X), I(X|Y)
# and I(Y->X).
mirrored = pd.DataFrame(
    {
        "X": df["Y"].to_numpy(),
        "Y": df["X"].to_numpy(),
        "I(X)": df["I(Y)"].to_numpy(),
        "I(X|Y)": df["I(Y|X)"].to_numpy(),
        "I(Y->X)": df["I(X->Y)"].to_numpy(),
    }
).set_index(["X", "Y"])

forward_cols = ["X", "Y", "I(Y)", "I(Y|X)", "I(X->Y)", "X_dim", "Y_dim"]
df = df[forward_cols].join(mirrored, on=["X", "Y"])
df


## Clustermap

In [None]:
# Derived scores. "/dim" divides by the dimensionality of the target side
# (X_dim for I(Y->X), Y_dim for I(X->Y)); "/I(.)" divides by the matching
# I(X) / I(Y) column. (The ratio pair used to be computed twice.)
df["I(Y->X)/dim"] = df["I(Y->X)"] / df["X_dim"]
df["I(X->Y)/dim"] = df["I(X->Y)"] / df["Y_dim"]

df["I(Y->X)/I(X)"] = df["I(Y->X)"] / df["I(X)"]
df["I(X->Y)/I(Y)"] = df["I(X->Y)"] / df["I(Y)"]

# Directional asymmetry as differences ...
df["I(Y->X) - I(X->Y)"] = df["I(Y->X)"] - df["I(X->Y)"]
df["I(Y->X)/dim - I(X->Y)/dim"] = df["I(Y->X)/dim"] - df["I(X->Y)/dim"]
df["I(Y->X)/I(X) - I(X->Y)/I(Y)"] = df["I(Y->X)/I(X)"] - df["I(X->Y)/I(Y)"]

# ... and as ratios; the 1e-8 keeps an exactly-zero denominator finite.
df["I(Y->X) / I(X->Y)"] = df["I(Y->X)"] / (df["I(X->Y)"] + 1e-8)
df["I(Y->X)/dim / I(X->Y)/dim"] = df["I(Y->X)/dim"] / (df["I(X->Y)/dim"] + 1e-8)
df["I(Y->X)/I(X) / I(X->Y)/I(Y)"] = df["I(Y->X)/I(X)"] / (df["I(X->Y)/I(Y)"] + 1e-8)

df

In [None]:
# Discard every pair that involves EdgePred on either side.
df = df[~df[["X", "Y"]].isin(["EdgePred"]).any(axis=1)]

In [None]:
# MI score columns of interest: raw and per-dimension scores in both
# directions, plus their Y->X minus X->Y differences.
keys = [
    "I(Y->X)",
    "I(X->Y)",
    "I(Y->X)/dim",
    "I(X->Y)/dim",
    "I(Y->X) - I(X->Y)",
    "I(Y->X)/dim - I(X->Y)/dim",
]

In [None]:
%matplotlib inline

fig,axes = plt.subplots(1,3,figsize=(15, 5), sharex=True)

sns.scatterplot(data=df, y="I(Y->X)", x="X_dim", hue="X", ax=axes[0], legend=False)

sns.scatterplot(data=df, y="I(Y->X)/dim", x="X_dim", hue="X", ax=axes[1], legend=False)

sns.scatterplot(data=df, y="I(Y->X)/I(X)", x="X_dim", hue="X", ax=axes[2], legend=False)


In [None]:
%matplotlib inline

fig,axes = plt.subplots(1,3,figsize=(15, 5), sharex=True)

sns.scatterplot(data=df, y="I(X->Y)", x="Y_dim", hue="Y", ax=axes[0], legend=False)
sns.scatterplot(data=df, y="I(X->Y)/dim", x="Y_dim", hue="Y", ax=axes[1], legend=False)
sns.scatterplot(data=df, y="I(X->Y)/I(Y)", x="Y_dim", hue="Y", ax=axes[2], legend=False)


In [None]:
# Keep only rows whose X also occurs as a Y somewhere in the table.
df = df.loc[df["X"].isin(df["Y"].unique())]
df

In [None]:
import os

from scipy.cluster.hierarchy import linkage


def plot_cmap(df, keys, cmap="copper", vmin=None, vmax=None, center=None, values=False):
    """Plot one Ward-clustered heatmap per column in ``keys``, tiled in one row.

    Each key is pivoted into an X-by-Y matrix (mean over duplicate pairs),
    clustered with a Ward linkage computed on the rows, drawn with
    ``sns.clustermap`` and saved to ``fig/cluster_<i>.png``. The saved images
    are then shown side by side in a single figure, one panel per key.

    Args:
        df: Long table with ``X``, ``Y`` columns and every column in ``keys``.
        keys: Column names to plot, one panel each.
        cmap: Colormap passed to ``sns.clustermap``.
        vmin, vmax: Optional per-key lists of colour limits (``None`` = auto).
        center: Value at which to centre a diverging colormap.
        values: If True, annotate each cell with its value.
    """
    import matplotlib.image as mpimg

    if not keys:
        return
    if vmax is None:
        vmax = [None] * len(keys)
    if vmin is None:
        vmin = [None] * len(keys)

    # savefig does not create missing directories.
    os.makedirs("fig", exist_ok=True)

    for i, key in enumerate(keys):
        df_pivot = df.pivot_table(index="X", columns="Y", values=key, aggfunc="mean")

        # The same linkage is reused for rows and columns, which assumes the
        # pivot is square with X and Y over the same models.
        # NOTE(review): linkage() rejects NaN cells — confirm every pair is present.
        link = linkage(df_pivot, method="ward")
        cluster = sns.clustermap(
            df_pivot, row_linkage=link, col_linkage=link,
            cmap=cmap, figsize=(8, 8), vmin=vmin[i], vmax=vmax[i], center=center, annot=values,
        )
        cluster.savefig(f"fig/cluster_{i}.png")
        # clustermap opens its own figure; close it (plt.clf() only emptied
        # it) so figures do not accumulate across keys and calls.
        plt.close(cluster.fig)

    # squeeze=False keeps `axes` indexable even when there is a single key.
    fig, axes = plt.subplots(1, len(keys), figsize=(8 * len(keys), 8), squeeze=False)
    for i, (ax, key) in enumerate(zip(axes[0], keys)):
        ax.imshow(mpimg.imread(f"fig/cluster_{i}.png"))
        ax.axis("off")
        ax.set_title(key)
    plt.show()

In [None]:
%matplotlib inline
plot_cmap(
    df,
    ["I(X->Y)", "I(X->Y)/dim"],
    cmap="viridis",
)

In [None]:
# Directional asymmetry (Y->X minus X->Y), diverging palette centred on 0.
asymmetry_keys = ["I(Y->X) - I(X->Y)", "I(Y->X)/dim - I(X->Y)/dim"]
plot_cmap(df, asymmetry_keys, cmap="seismic", center=0)

In [None]:
# Reduced model set: MODELS minus most size/backbone variants of the same
# family, for a more readable figure.
models_to_cons = [
    "ContextPred", "GPT-GNN", "GraphMVP", "GROVER",
    # "EdgePred",  # excluded: especially bad, makes the visualization hard
    "AttributeMask", "GraphLog", "GraphCL", "InfoGraph",
    "Not-trained",
    "MolBert",
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "DenoisingPretrainingPQCMv4",
    "FRAD_QM9",
    "MolR_tag",
    "MoleOOD_OGB_GIN",
    "ThreeDInfomax",
]

in_subset = df.X.isin(models_to_cons) & df.Y.isin(models_to_cons)
plot_cmap(df[in_subset], ["I(X->Y)", "I(X->Y)/dim"], cmap="rocket")

In [None]:
import matplotlib.patheffects as patheffects
import matplotlib.pyplot as plt
import networkx as nx

# White-grid seaborn theme for the graph figures below.
sns.set_style("whitegrid")




In [None]:
from netgraph import Graph, InteractiveGraph
from networkx.algorithms.community import girvan_newman, modularity_max, louvain_communities


def _rescale_weights(weights, clip_quantile, power, out_min, out_max):
    """Map an {edge: weight} dict onto [out_min, out_max] with a power curve.

    Weights at or below the ``clip_quantile`` quantile map to ``out_min`` and
    the largest weight maps to ``out_max``. In between, the clipped linear
    position ``t`` in [0, 1] is raised to ``power``.
    """
    if not weights:
        return {}
    values = np.fromiter(weights.values(), dtype=float)
    floor = np.quantile(values, clip_quantile)
    span = values.max() - floor  # hoisted: was recomputed for every edge

    def position(w):
        if span > 0:
            # Clip so sub-quantile weights land on 0 instead of a negative base
            # (which an even power turned back into a positive value).
            return min(max((w - floor) / span, 0.0), 1.0)
        # No spread above the floor (e.g. all weights equal): avoid 0/0.
        return 1.0 if w >= floor else 0.0

    return {edge: position(w) ** power * (out_max - out_min) + out_min for edge, w in weights.items()}


def plot_com(df, cmap="crest", min_alpha=0.0,
    max_alpha=1.0, alpha_pow=2, min_edge_width=0.0, max_edge_width=1, edge_pow=2, node_layout="community", edge_layout="bundled", com_resolution=1.1, figsize=10, clip_min_values_width=0, clip_min_values_alpha=0, com_pad_by=0.001, fontsize=15, seed=None):
    """Draw the directed graph of per-dimension predictive MI between models.

    Nodes are the models in ``df``; edge X -> Y carries the mean
    ``I(X->Y)/dim`` for that pair. Louvain communities drive the default
    "community" layout. Node fill / border colours come from the median
    outgoing / incoming edge weight, and edge alpha and width are rescaled
    from the edge weights.

    Args:
        df: Long table with ``X``, ``Y`` and ``I(X->Y)/dim`` columns.
        cmap: Seaborn palette name used for node and edge colours.
        min_alpha, max_alpha, alpha_pow: Output range and curve of edge alpha.
        min_edge_width, max_edge_width, edge_pow: Same for edge width.
        clip_min_values_alpha, clip_min_values_width: Quantile of the edge
            weights at or below which edges get the minimum alpha / width.
        node_layout, edge_layout: netgraph layout names.
        com_resolution: Louvain ``resolution`` (higher -> more communities).
        com_pad_by: ``pad_by`` for the community node layout.
        figsize: Side length (inches) of the square figure.
        fontsize: Node label font size.
        seed: Optional Louvain seed for reproducible communities (default:
            random, as before).
    """
    table = df.pivot_table(index="X", columns="Y", values="I(X->Y)/dim", aggfunc="mean")

    # Rows are X, columns Y -> directed edge X -> Y weighted by I(X->Y)/dim.
    # (Built once; the old code constructed the identical graph twice.)
    G = nx.from_pandas_adjacency(table, create_using=nx.DiGraph)
    G.remove_edges_from(list(nx.selfloop_edges(G)))

    communities = list(louvain_communities(G, resolution=com_resolution, seed=seed))
    node_to_community = {node: i for i, community in enumerate(communities) for node in community}

    colormap = sns.color_palette(cmap, as_cmap=True)
    # Raw weights are fed to the colormap; values outside [0, 1] get its
    # under/over colours. np.median of an empty list (no edges) yields NaN.
    avg_weight = {n: np.median([d["weight"] for _, _, d in G.out_edges(n, data=True)]) for n in G.nodes()}
    avg_income = {n: np.median([d["weight"] for _, _, d in G.in_edges(n, data=True)]) for n in G.nodes()}

    node_color = {node: colormap(avg_weight[node]) for node in G.nodes()}
    node_edge_color = {node: colormap(avg_income[node]) for node in G.nodes()}
    node_labels = {node: node for node in G.nodes()}

    weights = {edge: G.edges[edge]["weight"] for edge in G.edges()}
    edge_color = {edge: colormap(w) for edge, w in weights.items()}
    # The alpha/width arguments are now honoured: min_alpha/max_alpha used to
    # be overwritten, and edge_width was computed (with a misplaced exponent)
    # but never passed to Graph.
    edge_alpha = _rescale_weights(weights, clip_min_values_alpha, alpha_pow, min_alpha, max_alpha)
    edge_width = _rescale_weights(weights, clip_min_values_width, edge_pow, min_edge_width, max_edge_width)

    fig, ax = plt.subplots(figsize=(figsize, figsize))
    if node_layout == "community":
        node_layout_kwargs = dict(node_to_community=node_to_community, pad_by=com_pad_by)
    else:
        node_layout_kwargs = {}
    Graph(
        G,
        node_layout=node_layout, node_layout_kwargs=node_layout_kwargs,
        node_color=node_color, node_edge_color=node_edge_color, node_labels=node_labels,
        node_label_fontdict={"fontsize": fontsize, "fontweight": "bold"},
        edge_color=edge_color, edge_alpha=edge_alpha, edge_width=edge_width,
        edge_layout=edge_layout, arrows=True, prettify=True, ax=ax,
    )

    # White outline around every label so it stays readable over edges.
    for text in ax.texts:
        text.set_path_effects([patheffects.Stroke(linewidth=4, foreground="white"), patheffects.Normal()])

In [None]:
# Community graph over all models, exported to fig/MI_graph.pdf.
graph_kwargs = dict(
    figsize=15,
    com_resolution=1.2,
    min_edge_width=0,
    max_edge_width=.1,
    clip_min_values_alpha=0.1,
    cmap="flare",
    fontsize=10,
    com_pad_by=0.001,
)
plot_com(df, **graph_kwargs)
plt.savefig("fig/MI_graph.pdf", format="pdf", bbox_inches="tight")

In [None]:

# Same graph restricted to the reduced model set.
subset = df[df.X.isin(models_to_cons) & df.Y.isin(models_to_cons)]
plot_com(
    subset,
    figsize=15, com_resolution=1.2,
    min_edge_width=0, max_edge_width=.1,
    clip_min_values_alpha=0.1,
    cmap="flare", fontsize=10,
)


In [None]:
# Reduced model set with straight (unbundled) edges.
subset = df[df.X.isin(models_to_cons) & df.Y.isin(models_to_cons)]
plot_com(
    subset,
    edge_layout="straight",
    figsize=15,
    com_resolution=1.15,
    max_edge_width=.7,
    clip_min_values_alpha=0.5,
)

In [None]:
# Inspect the filtered MI table with all derived score columns.
df


In [None]:
%matplotlib inline
plt.figure(figsize=(6,4))
sns.scatterplot(data=df[df.X == "Not-trained"], x="I(Y->X)/dim", y="I(X->Y)/dim", hue="Y", legend=False)

plt.title("Predictive mutual information, X is a models that hasn't been trained")
plt.xlabel("I(Y->X)")
plt.ylabel("I(X->Y)")
plt.plot([0, 1.5], [0, 1.5], "r--")
plt.xlim(0.2,0.7)
plt.ylim(0,0.7)
plt.show()

In [None]:
%matplotlib inline
plt.figure(figsize=(6,4))
df_tmp = df[df.X == "DenoisingPretrainingPQCMv4"]

sns.scatterplot(data=df_tmp, x="I(Y->X)/dim", y="I(X->Y)/dim", hue="Y", legend=False)
sns.scatterplot(data=df_tmp[df_tmp.Y == "FRAD_QM9"], x="I(Y->X)/dim", y="I(X->Y)/dim", color="black")

plt.title("Predictive mutual information, X is a denosing 3D model")
plt.xlabel("I(Y->X)")
plt.ylabel("I(X->Y)")
plt.plot([0, 1.5], [0, 1.5], "r--")
plt.xlim(0,0.4)
plt.ylim(0.1,0.7)
plt.show()

In [None]:
# Average the numeric scores over X for each model Y. "information" is the
# mean I(Y->X)/dim, i.e. how well Y predicts the other models on average.
# numeric_only=True: df still holds the string column X, which
# GroupBy.mean() rejects with a TypeError on pandas >= 2.0.
df_avg = df.groupby("Y").mean(numeric_only=True)
df_avg["information"] = df_avg["I(Y->X)/dim"]
fig, ax = plt.subplots(figsize=(6, 6))
sns.barplot(df_avg.sort_values("information"), x="information", hue="Y", legend=False, ax=ax, y="Y")

plt.title("Mean predictive mutual information of a model to predict the other models")
plt.ylabel("")
plt.xlabel("Mean predictive mutual information")
plt.savefig("fig/mean_information.pdf", format="pdf", bbox_inches="tight")

In [None]:
# Downstream results, one row per (dataset, embedder, ...) with a "metric".
df_downs = pd.read_csv("results/tmp.csv")
# Attach each embedder's mean information. df_avg is indexed by model name
# ("Y" is its index, not a column, so the old rename(columns=...) was a
# no-op); join(on=...) matches the "embedder" column against that index.
df_downs = df_downs.join(df_avg["information"], on="embedder")
df_downs.sample(3)

In [None]:
# Spot-check one (dataset, embedder) combination.
df_downs.loc[df_downs["dataset"].eq("DILI") & df_downs["embedder"].eq("ChemBertMLM-10M")]

In [None]:
from autorank import autorank
#rank each model on each task


In [None]:
# Score matrix for ranking: one row per dataset, one column per embedder
# (mean over duplicate rows, if any).
df_to_rank = pd.pivot_table(
    df_downs, index="dataset", columns="embedder", values="metric", aggfunc="mean"
)

df_to_rank

In [None]:
# Rank embedders across datasets (autorank, nonparametric mode forced) and
# attach each embedder's mean rank to every one of its rows.
res = autorank(df_to_rank, alpha=0.05, verbose=False, force_mode="nonparametric").rankdf

df_downs["meanrank_metric"] = df_downs["embedder"].map(lambda name: res.loc[name, "meanrank"])

In [None]:
# Rank models by I(Y->X): rows are the predicted model X, columns the
# predicting model Y, so autorank compares the Y columns.
df_to_rank = pd.pivot_table(
    df[["Y", "I(Y->X)", "X"]], index="X", columns="Y", values="I(Y->X)", aggfunc="mean"
)

res = autorank(df_to_rank, alpha=0.05, verbose=False).rankdf

df_downs["meanrank_information"] = df_downs["embedder"].map(lambda name: res.loc[name, "meanrank"])

In [None]:
# Inspect the downstream table with information scores and both mean ranks.
df_downs

In [None]:
# One panel per downstream dataset: each embedder's metric against its mean
# information, with a linear fit.
datasets = df_downs.dataset.unique()
n_rows = 3
# Round up so every dataset gets a panel (floor division dropped the
# remainder and raised IndexError).
n_cols = max(1, int(np.ceil(len(datasets) / n_rows)))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), sharey=True, squeeze=False)
axes = axes.flatten()

for ax, dataset in zip(axes, datasets):
    df_tmp = df_downs[df_downs.dataset == dataset]
    sns.scatterplot(data=df_tmp, x="metric", y="information", hue="embedder", ax=ax, legend=False)
    sns.regplot(data=df_tmp, x="metric", y="information", ax=ax, scatter=False, color="blue", line_kws={"alpha": 0.2})
    ax.set_title(dataset)

# Hide grid cells left over when the dataset count is not a multiple of n_rows.
for ax in axes[len(datasets):]:
    ax.set_visible(False)

fig.tight_layout()
plt.savefig("fig/roc_vs_information.pdf", format="pdf", bbox_inches="tight")

In [None]:
# Per-embedder means over all datasets. numeric_only=True: df_downs holds
# string columns (e.g. "dataset") that GroupBy.mean() rejects on pandas >= 2.0.
per_embedder = df_downs.groupby("embedder").mean(numeric_only=True)
sns.scatterplot(data=per_embedder, x="metric", y="information", hue="embedder", legend=False)
sns.regplot(data=per_embedder, x="metric", y="information", scatter=False, color="blue", line_kws={"alpha": 0.2})

In [None]:
fig = plt.figure(figsize=(3, 3))

# Mean rank by predictivity vs. mean rank on downstream tasks, one point per
# embedder. numeric_only=True skips string columns (pandas >= 2.0 raises on them).
per_embedder = df_downs.groupby("embedder").mean(numeric_only=True)
sns.scatterplot(data=per_embedder, x="meanrank_information", y="meanrank_metric", hue="embedder", legend=False)
sns.regplot(data=per_embedder, x="meanrank_information", y="meanrank_metric", scatter=False, color="blue", line_kws={"alpha": 0.2})

plt.xlabel("Mean rank of models' predictivity")
plt.ylabel("Mean rank on downstream tasks")

plt.savefig("fig/meanrank.pdf", format="pdf", bbox_inches="tight")

In [None]:
# One panel per dataset: each embedder's metric, embedders ordered by information.
datasets = df_downs.dataset.unique()
n_rows = 3
# Round up so every dataset gets a panel (floor division could drop some).
n_cols = max(1, int(np.ceil(len(datasets) / n_rows)))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), sharey=True, squeeze=False)
axes = axes.flatten()

for ax, dataset in zip(axes, datasets):
    # Filter first, then sort. The old code masked the sorted frame with a
    # boolean Series built on the unsorted one, which relied on index
    # alignment and warned.
    df_tmp = df_downs[df_downs.dataset == dataset].sort_values("information")
    sns.scatterplot(data=df_tmp, x="metric", y="embedder", hue="embedder", ax=ax, legend=False)
    ax.set_title(dataset)
    # Rotate tick labels without freezing them (set_xticklabels on
    # non-fixed ticks triggers a FixedFormatter warning).
    ax.tick_params(axis="x", labelrotation=90)

for ax in axes[len(datasets):]:
    ax.set_visible(False)

fig.tight_layout()

In [None]:
%matplotlib inline
fig = plt.figure(figsize=(6,6))
sns.scatterplot(data=df_downs.groupby("embedder").mean().sort_values("information"), y="embedder", x="meanrank_metric", hue="embedder", legend=False)

plt.savefig("fig/meanrank_roc.pdf", format = "pdf", bbox_inches = 'tight')