In [None]:
from emir.estimators import KNIFEEstimator, KnifeArgs

In [2]:
import os
from collections import namedtuple

from models.moleculenet_models import GNN, GNN_graphpred
from data.moleculenet_encoding import mol_to_graph_data_obj_simple
import datamol as dm
import torch
from torch_geometric.data import DataLoader
import torch_geometric.nn.pool as tgp

# Keyword arguments for the GNN backbone; unpacked as ``GNN(**MODEL_PARAMS)``
# when a pretrained checkpoint is loaded.
MODEL_PARAMS = dict(
    num_layer=5,
    emb_dim=300,
    JK="last",
    drop_ratio=0.5,
    gnn_type="gin",
)

In [3]:
# Load the FreeSolv table; its "smiles" column is the source of every molecule below.
df = dm.data.freesolv()

# One graph object per SMILES string, served in batches of 32.
# shuffle=False keeps the batches in the same order as the rows of `df`.
dataloader = DataLoader(
    dataset=[
        mol_to_graph_data_obj_simple(dm.to_mol(smi))
        for smi in df["smiles"]
    ],
    shuffle=False,
    batch_size=32,
)



In [4]:
@torch.no_grad()
def get_embeddings_from_model(
        path: str = "backbone_pretrained_models/GROVER/grover.pth",
        pooling_method=tgp.global_mean_pool
):
    """Embed every graph served by the module-level ``dataloader`` with a pretrained GNN.

    A fresh ``GNN(**MODEL_PARAMS)`` is built, the state dict stored at ``path`` is
    loaded into it, and each batch is passed through the model and pooled.

    Args:
        path: Checkpoint file containing a state dict compatible with
            ``GNN(**MODEL_PARAMS)``.
        pooling_method: Callable invoked as ``pooling_method(model_output, batch.batch)``
            to reduce the model output of a batch; defaults to global mean pooling.

    Returns:
        The pooled outputs of all batches concatenated along dim 0, in dataloader order.
    """
    molecule_model = GNN(**MODEL_PARAMS)
    # The model and the batches are never moved to another device in this function,
    # so map the checkpoint to CPU; this also lets GPU-saved checkpoints load on
    # machines without CUDA.
    molecule_model.load_state_dict(torch.load(path, map_location="cpu"))
    # Switch to inference mode: MODEL_PARAMS sets drop_ratio=0.5, and a module left in
    # training mode keeps dropout (and any batch-statistics layers) active, which makes
    # the extracted embeddings random and batch-dependent.
    molecule_model.eval()

    embeddings = [
        pooling_method(molecule_model(b.x, b.edge_index, b.edge_attr), b.batch)
        for b in dataloader
    ]
    return torch.cat(embeddings, dim=0)

In [5]:
# Root folder expected to contain one sub-directory per pretrained model.
MODEL_PATH = "backbone_pretrained_models"
# Maps sub-directory name -> path of a ".pth" checkpoint found inside it.
MODELS = {}
# Sorted listings make the discovery order (and the checkpoint kept when a folder
# holds several ".pth" files: the last one alphabetically) reproducible.
for model_name in sorted(os.listdir(MODEL_PATH)):
    model_dir = os.path.join(MODEL_PATH, model_name)
    # Stray files at the top level (README, .DS_Store, ...) are not model folders;
    # calling os.listdir on them would raise NotADirectoryError.
    if not os.path.isdir(model_dir):
        continue
    for file_name in sorted(os.listdir(model_dir)):
        if file_name.endswith(".pth"):
            MODELS[model_name] = os.path.join(model_dir, file_name)

In [6]:
# Embedding tensor per discovered pretrained model, keyed by model name.
embeddings = {}
for model_name in MODELS:
    model_path = MODELS[model_name]
    embeddings[model_name] = get_embeddings_from_model(model_path)

In [7]:
from molfeat.trans.fp import FPVecTransformer
from molfeat.trans import MoleculeTransformer
# Pharmacophore2D is used by the pharmacophore loop below and was previously never
# imported, which made that loop fail with a NameError.
from molfeat.calc.pharmacophore import Pharmacophore2D

# Featurizer names grouped by the transformer used to build them.
# The two 3D lists are defined here but not consumed by the loops in this cell.
threeD_method_fpvec = ["usrcat", "electroshape", "usr"]
threeD_method_moleculetransf = ["cats3d",]
fpvec_method = ["ecfp-count", "ecfp",  "estate", "erg", "rdkit", "topological", "avalon", "maccs"]
moleculetransf_method = ["scaffoldkeys", "cats2d", ]
pharmac_method = ["cats", "default", "gobbi", "pmapper"]

# Each loop featurizes the SMILES column of `df` and stores the result as a tensor in
# `embeddings` under the featurizer name, next to the pretrained-model embeddings.

# Fingerprint-style featurizers.
for name in fpvec_method:
    transformer = FPVecTransformer(kind=name, dtype=float)
    embeddings[name] = torch.tensor(transformer(df["smiles"]))

# Featurizers addressed by name through the generic MoleculeTransformer.
for name in moleculetransf_method:
    transformer = MoleculeTransformer(featurizer=name, dtype=float)
    embeddings[name] = torch.tensor(transformer(df["smiles"]))

# 2D pharmacophore featurizers, one per feature factory.
for name in pharmac_method:
    transformer = MoleculeTransformer(featurizer=Pharmacophore2D(factory=name), dtype=float)
    embeddings[name] = torch.tensor(transformer(df["smiles"]))



In [9]:


def get_knife_loss(key1, key2):
    """Run the KNIFE estimator on two stored embedding sets.

    Args:
        key1: Key of the first embedding tensor in the module-level ``embeddings`` dict.
        key2: Key of the second embedding tensor in ``embeddings``.

    Returns:
        The three values returned by ``KNIFEEstimator.eval(x1, x2)``, unchanged
        (presumably the mutual-information estimate followed by two entropy terms;
        confirm against the estimator's documentation).
    """
    x1 = embeddings[key1]
    x2 = embeddings[key2]
    # `KnifeArgs` is the name this notebook imports from emir.estimators; the previous
    # spelling `KNIFEArgs` is never bound here and raised a NameError.
    # NOTE(review): the args class itself (not an instance) is passed, as before;
    # confirm whether the estimator expects `KnifeArgs()` instead.
    knife_estimator = KNIFEEstimator(KnifeArgs, x1.shape[1], x2.shape[1])
    mi, m, c = knife_estimator.eval(x1, x2)
    return mi, m, c


In [None]:
# Compare the "ecfp-count" fingerprint features with the embeddings stored under "GROVER"
# (a key that exists only if a "GROVER" folder with a .pth file was found under MODEL_PATH).
get_knife_loss("ecfp-count", "GROVER")