In [None]:
import os
from functools import partial
from collections import namedtuple
from typing import Tuple, List

import numpy as np

import torch
from torch_geometric.data import DataLoader
import datamol as dm

import matplotlib.pyplot as plt

from emir.estimators import KNIFEEstimator, KNIFEArgs

from models.moleculenet_models import GNN, GNN_graphpred
from moleculenet_encoding import mol_to_graph_data_obj_simple
from utils import get_embeddings_from_model, get_molfeat_descriptors,get_molfeat_transformer

from tqdm import tqdm

# Hyper-parameter bundle; presumably intended for the GNN backbone imported above — TODO confirm.
# NOTE(review): MODEL_PARAMS is not referenced anywhere else in the visible notebook.
MODEL_PARAMS = dict(
    num_layer=5,
    emb_dim=300,
    JK="last",
    drop_ratio=0.5,
    gnn_type="gin",
)


In [None]:
from tdc.utils import retrieve_label_name_list
from tdc.single_pred import Tox
# Load the 'hERG_Karim' dataset through TDC's Tox loader and display the resulting table.
herg_task = Tox(name="hERG_Karim")
df = herg_task.get_data()
df

In [None]:
# `mols` starts as a placeholder; it is populated by the conformer-generation cell that follows.
mols = None
# Values of the dataset's "Drug" column as a plain Python list (parsed as SMILES further down).
smiles = df.loc[:, "Drug"].tolist()


In [None]:
# Parse every SMILES string and generate conformers for it (n_confs=5, align_conformers=True).
mols = []
for smi in tqdm(smiles, desc="Generating conformers"):
    parsed = dm.to_mol(smi)
    mols.append(dm.conformers.generate(parsed, align_conformers=True, n_confs=5))

# Show the first molecule as a quick sanity check.
mols[0]

In [None]:
# Build the "usr" molfeat transformer; the second returned value (thrD) is not used in the visible notebook.
transformer, thrD = get_molfeat_transformer("usr")
# Called with ignore_errors=True, the transformer returns a (features, valid_id) pair.
# Presumably valid_id holds the positions of the molecules that were featurised successfully — TODO confirm.
feat, valid_id = transformer(mols, ignore_errors=True)
# Keep only the entries selected by valid_id; both sequences become NumPy arrays.
smiles, mols = (np.array(seq)[valid_id] for seq in (smiles, mols))

In [None]:
# Re-parse each retained SMILES string and convert it into a graph data object.
graph_dataset = [mol_to_graph_data_obj_simple(dm.to_mol(smi)) for smi in tqdm(smiles)]
# shuffle=False: batches follow the order of `smiles`, which get_knife_preds relies on when it
# pairs model embeddings with descriptor embeddings computed from `smiles`/`mols`.
dataloader = DataLoader(graph_dataset, batch_size=32, shuffle=False)

In [None]:
# Root folder holding one sub-directory per pretrained backbone.
MODEL_PATH = "backbone_pretrained_models"
# Maps model name (the sub-directory name) -> path of its ".pth" checkpoint.
MODELS = {}
# Sorted so that the discovery order (and therefore the key order of MODELS) is reproducible;
# os.listdir returns entries in arbitrary order.
for model_name in sorted(os.listdir(MODEL_PATH)):
    model_dir = os.path.join(MODEL_PATH, model_name)
    # Skip stray files (README, .DS_Store, ...): listing them would raise NotADirectoryError.
    if not os.path.isdir(model_dir):
        continue
    # If a directory holds several checkpoints, the last one in sorted order is kept,
    # which makes the choice deterministic across machines.
    for file_name in sorted(os.listdir(model_dir)):
        if file_name.endswith(".pth"):
            MODELS[model_name] = os.path.join(model_dir, file_name)

In [None]:
# Descriptor names that are computed through get_molfeat_descriptors.
descriptors = ["ecfp", "rdkit", "topological", "scaffoldkeys"]

# Map every embedding name to a callable that produces it:
#   * entries of MODELS -> get_embeddings_from_model bound to the checkpoint path
#   * descriptors       -> get_molfeat_descriptors bound to a single transformer name
# Descriptors are added last, so a descriptor sharing a name with a model would replace it.
embeddings_fn = {
    name: partial(get_embeddings_from_model, path=checkpoint)
    for name, checkpoint in MODELS.items()
}
embeddings_fn.update(
    (name, partial(get_molfeat_descriptors, transformer_names=[name]))
    for name in descriptors
)

In [None]:
# Settings handed to every KNIFEEstimator built by get_knife_preds.
# NOTE(review): the name looks like a typo of "Knife_config"; it is kept because
# get_knife_preds refers to it under this spelling.
# The meaning of each field is defined by emir's KNIFEArgs — TODO document after checking its docs.
Knige_config = KNIFEArgs(
    batch_size=128,
    lr=0.01,
    n_epochs=30,
    device="cpu",
    cond_modes=3,
    marg_modes=3,
    ff_layers=2,
    cov_diagonal="var",
    cov_off_diagonal="",
)

In [None]:
def _get_embedding(key: str):
    """Compute the embedding registered under ``key`` in ``embeddings_fn``.

    Keys present in ``MODELS`` are called with the graph ``dataloader``; any other key is
    called with ``smiles`` and ``mols=mols`` and the first element of its result is used.

    Raises:
        KeyError: if ``key`` is not registered in ``embeddings_fn``.
    """
    if key not in embeddings_fn:
        raise KeyError(
            f"Unknown embedding key {key!r}; available keys: {sorted(embeddings_fn)}"
        )
    embed = embeddings_fn[key]
    if key in MODELS:
        return embed(dataloader)
    return embed(smiles, mols=mols)[0]


def get_knife_preds(key1: str, key2: str) -> Tuple[float, float, float, List[float]]:
    """Run a fresh KNIFE estimator on the embeddings named ``key1`` and ``key2``.

    Both embeddings are recomputed on every call, and a new ``KNIFEEstimator`` is built from
    ``Knige_config`` and the second dimension of each embedding.

    Args:
        key1: name of the first embedding (a key of ``embeddings_fn``).
        key2: name of the second embedding (a key of ``embeddings_fn``).

    Returns:
        ``(mi, m, c, recorded_loss)``: the three values returned by ``KNIFEEstimator.eval``
        followed by the estimator's ``recorded_loss``. Presumably ``mi`` is the mutual
        information estimate and ``m`` / ``c`` its marginal and conditional terms — TODO confirm
        against the emir documentation.

    Raises:
        KeyError: if either key is not registered in ``embeddings_fn``.
    """
    x1 = _get_embedding(key1)
    x2 = _get_embedding(key2)
    # Embedding widths, printed as a quick check of what is being compared.
    print(x1.shape[1], x2.shape[1])
    knife_estimator = KNIFEEstimator(Knige_config, x1.shape[1], x2.shape[1])
    mi, m, c = knife_estimator.eval(x1.float(), x2.float(), record_loss=True)
    return mi, m, c, knife_estimator.recorded_loss

def get_knife_preds_plot_loss(key1: str, key2: str):
    """Run ``get_knife_preds`` on the two keys and plot the loss it recorded.

    The figure title shows the key pair and the three values returned alongside the loss,
    formatted with three decimals. Nothing is returned.
    """
    mi, m, c, loss_curve = get_knife_preds(key1, key2)
    plt.plot(loss_curve)
    caption = f"{key1} vs {key2}\nMI: {mi:.3f}, M: {m:.3f}, C: {c:.3f}"
    plt.title(caption)
    plt.show()


In [None]:
# Loss curve for the pair (GROVER, ecfp); requires a "GROVER" entry in MODELS.
get_knife_preds_plot_loss(key1="GROVER", key2="ecfp")

In [None]:
# Same pair with the argument order swapped: (ecfp, GROVER).
get_knife_preds_plot_loss(key1="ecfp", key2="GROVER")

In [None]:
# Loss curve for the pair (scaffoldkeys, GROVER).
get_knife_preds_plot_loss(key1="scaffoldkeys", key2="GROVER")

In [None]:
from itertools import product
import pandas as pd

def model_profile(model_name):
    """Pair ``model_name`` with every entry of ``descriptors`` and tabulate the result.

    For each descriptor, ``get_knife_preds(model_name, descriptor)`` is run, its first
    returned value (``mi``) is printed with three decimals and collected.

    Returns:
        A DataFrame with one row per descriptor and the columns ``desc1`` (always
        ``model_name``), ``desc2`` (the descriptor name) and ``mi``.
    """
    mi_scores = []
    for descriptor in descriptors:
        mi, _m, _c, _loss = get_knife_preds(model_name, descriptor)
        print(f"{model_name} vs {descriptor}: {mi:.3f}")
        mi_scores.append(mi)
    return pd.DataFrame(
        {
            "desc1": [model_name] * len(mi_scores),
            "desc2": list(descriptors),
            "mi": mi_scores,
        }
    )

In [None]:
# Display the model names discovered under MODEL_PATH; these are the keys that
# get_knife_preds treats as pretrained models (fed with `dataloader`).
MODELS.keys()

In [None]:
# Profile of the "GROVER" model against every descriptor in `descriptors`.
results_GROVER = model_profile(model_name="GROVER")

In [None]:
# Profile of the "GraphMVP" model against every descriptor in `descriptors`.
results_GRAPHMVP = model_profile(model_name="GraphMVP")

In [None]:
# Profile of the "AttributeMask" model against every descriptor in `descriptors`.
results_attm = model_profile(model_name="AttributeMask")

In [None]:
# Stack the per-model profiles into one long table for plotting.
# ignore_index=True gives the combined frame a fresh 0..n-1 index instead of repeating each
# per-model frame's own row labels, so rows stay uniquely addressable.
# NOTE(review): results_GROVER is computed earlier in the notebook but is not part of this
# comparison — confirm the omission is intentional.
full_df = pd.concat([results_GRAPHMVP, results_attm], ignore_index=True)

In [None]:
import seaborn as sns

In [None]:
# Grouped bar chart of the `mi` column: one group per descriptor (desc2), one colour per model (desc1).
sns.barplot(x="desc2", y="mi", hue="desc1", data=full_df)