In [None]:
import os
from collections import namedtuple
from typing import Tuple, List
from models.moleculenet_models import GNN, GNN_graphpred
from data.moleculenet_encoding import mol_to_graph_data_obj_simple
import datamol as dm
import torch
from torch_geometric.data import DataLoader
import torch_geometric.nn.pool as tgp
import matplotlib.pyplot as plt

from emir.estimators import KNIFEEstimator, KNIFEArgs


# Keyword arguments unpacked into the GNN constructor whenever a pretrained
# backbone is rebuilt later in the notebook.
MODEL_PARAMS = dict(
    num_layer=5,
    emb_dim=300,
    JK="last",
    drop_ratio=0.5,
    gnn_type="gin",
)

# Dataset returned by datamol's FreeSolv loader; the bare name on the last
# line makes the notebook render it as the cell output.
df = dm.data.freesolv()
df

In [None]:

# SMILES -> molecule -> graph object, one per entry of df["smiles"].
# shuffle=False keeps the batches in the order of that column.
dataloader = DataLoader(
    list(map(mol_to_graph_data_obj_simple, map(dm.to_mol, df["smiles"]))),
    batch_size=32,
    shuffle=False,
)

In [None]:
@torch.no_grad()
def get_embeddings_from_model(
        path:str = "backbone_pretrained_models/GROVER/grover.pth",
        pooling_method = tgp.global_mean_pool
):
    """Embed every molecule served by the global ``dataloader`` with a pretrained GNN.

    A ``GNN`` is built from ``MODEL_PARAMS``, its weights are loaded from
    ``path``, and the model output for each batch is reduced to one row per
    graph with ``pooling_method``.

    Args:
        path: Checkpoint file containing a state dict accepted by
            ``GNN.load_state_dict``.
        pooling_method: Callable invoked as
            ``pooling_method(model_output, batch.batch)``.

    Returns:
        The pooled outputs of all batches concatenated along dim 0, in the
        order the batches are produced by ``dataloader``.
    """
    molecule_model = GNN(**MODEL_PARAMS)
    # The model is never moved off the CPU, so load the checkpoint onto the
    # CPU as well; without map_location a checkpoint saved from a GPU cannot
    # be deserialised on a machine without CUDA.
    molecule_model.load_state_dict(torch.load(path, map_location="cpu"))
    # Switch to inference mode. MODEL_PARAMS sets drop_ratio=0.5, so leaving
    # the module in training mode would presumably apply dropout and make the
    # embeddings differ between runs (NOTE(review): confirm that GNN uses
    # dropout / batch-norm layers that depend on the training flag).
    molecule_model.eval()
    embeddings = [
        pooling_method(molecule_model(b.x, b.edge_index, b.edge_attr), b.batch)
        for b in dataloader
    ]
    return torch.cat(embeddings, dim=0)

In [None]:
MODEL_PATH = "backbone_pretrained_models"
MODELS = {}
# Every sub-directory of MODEL_PATH is treated as one pretrained model, keyed
# by the directory name. Listings are sorted because os.listdir() returns
# entries in arbitrary order; sorting makes the scan (and which checkpoint
# wins when a directory holds several .pth files) reproducible.
for model_name in sorted(os.listdir(MODEL_PATH)):
    model_dir = os.path.join(MODEL_PATH, model_name)
    # Plain files sitting next to the model directories (README, .DS_Store,
    # ...) would make the inner os.listdir() raise NotADirectoryError.
    if not os.path.isdir(model_dir):
        continue
    for file_name in sorted(os.listdir(model_dir)):
        # Only .pth files are considered checkpoints; the last one seen is kept.
        if file_name.endswith(".pth"):
            MODELS[model_name] = os.path.join(model_dir, file_name)

In [None]:
# Pooled GNN embeddings for every discovered checkpoint, keyed by model name.
embeddings = {
    name: get_embeddings_from_model(checkpoint)
    for name, checkpoint in MODELS.items()
}

In [None]:
from molfeat.trans.fp import FPVecTransformer
from molfeat.trans import MoleculeTransformer
# molfeat featurizer names, grouped by the transformer class meant to build
# them. Only `fpvec_method` is consumed by the active code in this cell.
threeD_method_fpvec = ["usrcat", "electroshape", "usr"]
threeD_method_moleculetransf = ["cats3d"]
fpvec_method = [
    "ecfp-count",
    "ecfp",
    "estate",
    "erg",
    "rdkit",
    "topological",
    "avalon",
    "maccs",
]
moleculetransf_method = ["scaffoldkeys", "cats2d"]
pharmac_method = ["cats", "default", "gobbi", "pmapper"]

# Store one fingerprint tensor per FPVecTransformer kind in the shared
# `embeddings` dict, keyed by the featurizer name.
for name in fpvec_method:
    featurize = FPVecTransformer(kind=name, dtype=float)
    embeddings[name] = torch.tensor(featurize(df["smiles"]))

# Disabled experiment: featurizers built through MoleculeTransformer.
#for name in moleculetransf_method:
#    transformer = MoleculeTransformer(featurizer=name, dtype=float)
#    embeddings[name] = torch.tensor(transformer(df["smiles"]))

# Disabled experiment: pharmacophore featurizers. NOTE: no import of
# Pharmacophore2D is visible in this notebook; add one before re-enabling.
#for name in pharmac_method:
#    transformer = MoleculeTransformer(featurizer=Pharmacophore2D(factory=name), dtype=float)
#    embeddings[name] = torch.tensor(transformer(df["smiles"]))



In [None]:
# Display the name -> tensor mapping built by the previous cells.
embeddings

In [None]:
# Settings passed to every KNIFEEstimator created by get_knife_preds.
Knige_config = KNIFEArgs(
    device="cpu",
    batch_size=32,
    n_epochs=100,
    lr=1e-3,
    ff_layers=3,
    cond_modes=16,
    marg_modes=16,
)

In [None]:
def get_knife_preds(key1: str, key2: str) -> Tuple[float, float, float, List[float]]:
    """Run a new KNIFEEstimator on two entries of the global ``embeddings`` dict.

    Args:
        key1: Key of the first embedding in ``embeddings``.
        key2: Key of the second embedding in ``embeddings``.

    Returns:
        The three values returned by ``KNIFEEstimator.eval`` followed by the
        estimator's ``recorded_loss``.
    """
    first, second = embeddings[key1], embeddings[key2]
    # The estimator is sized from the second dimension of each embedding.
    estimator = KNIFEEstimator(Knige_config, first.shape[1], second.shape[1])
    mi_value, m_value, c_value = estimator.eval(first, second, record_loss=True)
    return mi_value, m_value, c_value, estimator.recorded_loss

def get_knife_preds_plot_loss(key1:str, key2:str, ylim: Tuple[float, float] = (0, 2000)):
    """Run ``get_knife_preds`` for two embeddings and plot the recorded loss.

    Args:
        key1: Key of the first embedding, forwarded to ``get_knife_preds``.
        key2: Key of the second embedding, forwarded to ``get_knife_preds``.
        ylim: ``(bottom, top)`` limits of the y axis. The default reproduces
            the previously hard-coded ``(0, 2000)`` window.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    mi, m, c, recorded_loss = get_knife_preds(key1, key2)
    # Draw on a dedicated figure instead of pyplot's implicit current one, so
    # a call never paints over a figure left open by earlier code.
    fig, ax = plt.subplots()
    ax.plot(recorded_loss)
    ax.set_title(f"{key1} vs {key2}\nMI: {mi:.3f}, M: {m:.3f}, C: {c:.3f}")
    ax.set_xlabel("Recording index")
    ax.set_ylabel("Loss")
    ax.set_ylim(*ylim)
    plt.show()


In [None]:
# KNIFE run between the "ecfp" and "GROVER" entries of `embeddings`.
get_knife_preds_plot_loss(key1="ecfp", key2="GROVER")


In [None]:
# KNIFE run between the "ecfp" and "AttributeMask" entries of `embeddings`.
get_knife_preds_plot_loss(key1="ecfp", key2="AttributeMask")


In [None]:
# KNIFE run between the "ecfp" and "GraphMVP" entries of `embeddings`.
get_knife_preds_plot_loss(key1="ecfp", key2="GraphMVP")


In [None]:
# KNIFE run between the "topological" and "GROVER" entries of `embeddings`.
get_knife_preds_plot_loss(key1="topological", key2="GROVER")


In [None]:
# KNIFE run between the "topological" and "AttributeMask" entries of `embeddings`.
get_knife_preds_plot_loss(key1="topological", key2="AttributeMask")


In [None]:
# KNIFE run between the "topological" and "GraphMVP" entries of `embeddings`.
get_knife_preds_plot_loss(key1="topological", key2="GraphMVP")


In [None]:
# KNIFE run between the "GROVER" and "GraphMVP" entries of `embeddings`.
get_knife_preds_plot_loss(key1="GROVER", key2="GraphMVP")


In [None]:
# KNIFE run between the "AttributeMask" and "GraphMVP" entries of `embeddings`.
get_knife_preds_plot_loss(key1="AttributeMask", key2="GraphMVP")
