In [7]:
import os

import torch
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from torch_geometric.nn import GCNConv, Set2Set, GNNExplainer

from Source.model import MolGraphNet
from Source.mol_featurizer import featurize_sdf

In [8]:
# --- Configuration -----------------------------------------------------------
# The defaults point at the original author's machine. Override them through the
# environment variables below rather than editing the notebook.
path_to_data = os.environ.get(
    "SCD_DATA_PATH",
    "/home/cairne/PythonProj/SmartChemDesign/mol_torch_model/Data/An_converted/Am_ML.sdf",
)
# Target property name passed to featurize_sdf(valuenames=[...]).
valuename = "logK"
# Trained weights (a state_dict) to explain.
path_to_model = os.environ.get(
    "SCD_MODEL_PATH",
    "/home/cairne/PythonProj/SmartChemDesign/mol_torch_model/Output/Results_Am_logK_regression_2022_07_05_16_08_50/fold_1/best_model",
)
# Edges whose GNNExplainer mask value is strictly greater than this get highlighted.
threshold = 0.9


In [9]:
# Explain the model's prediction for every molecule in the SDF with GNNExplainer
# and save a PNG per molecule with the atoms/bonds of important edges highlighted.
device = torch.device("cpu")
# GNNExplainer initialises its masks randomly; seed so the images are reproducible.
torch.manual_seed(0)

# Name the output folder after the threshold (0.9 -> "images_0_9") so runs with
# different thresholds don't overwrite each other, and make sure it exists
# before any PNG is written into it.
output_dir = "images_" + str(threshold).replace(".", "_")
os.makedirs(output_dir, exist_ok=True)

mols = Chem.SDMolSupplier(path_to_data)

# Featurization and model loading do not depend on the molecule being explained,
# so do them once here instead of once per molecule inside the loop.
dataset = featurize_sdf(path_to_data, valuenames=[valuename])
# dataset[i] is paired with mols[i] below; if the featurizer dropped or added
# records, every explanation after that point would be drawn on the wrong molecule.
if len(dataset) != len(mols):
    raise ValueError(
        f"featurize_sdf produced {len(dataset)} graphs but {path_to_data} "
        f"contains {len(mols)} records; cannot align graphs with molecules"
    )

# NOTE(review): dataset[1] is presumably only a template from which MolGraphNet
# reads input feature sizes — TODO confirm.
model = MolGraphNet(dataset[1])
model.load_state_dict(torch.load(path_to_model, map_location=device))
model.eval()
explainer = GNNExplainer(model, epochs=1000, return_type="regression", allow_edge_mask=True)


def explanation_highlights(mol, edge_index, edge_mask, threshold):
    """Collect the atoms and bonds touched by edges whose mask exceeds ``threshold``.

    Args:
        mol: RDKit molecule whose atom indices match the graph's node indices.
        edge_index: ``[sources, targets]`` as two lists of ints.
        edge_mask: one importance value per edge, aligned with ``edge_index``.
        threshold: edges with a mask value strictly greater than this are kept.

    Returns:
        ``(atom_indices, bond_indices)`` as sorted lists. Edges that are not RDKit
        bonds (e.g. self-loops, if the featurizer adds any) still highlight their
        atoms but contribute no bond.
    """
    hit_atoms, hit_bonds = set(), set()
    for src, dst, weight in zip(edge_index[0], edge_index[1], edge_mask):
        if weight > threshold:
            hit_atoms.update((src, dst))
            bond = mol.GetBondBetweenAtoms(src, dst)
            if bond is not None:
                hit_bonds.add(bond.GetIdx())
    return sorted(hit_atoms), sorted(hit_bonds)


for mol_id, mol in enumerate(mols):
    if mol is None:  # SDMolSupplier yields None for records RDKit cannot parse
        print(f"Skipping record {mol_id}: RDKit could not parse it")
        continue

    batch = dataset[mol_id]
    x, edge_index = batch.x, batch.edge_index
    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    hit_ats, hit_bonds = explanation_highlights(
        mol, edge_index.tolist(), edge_mask.tolist(), threshold
    )

    d = rdMolDraw2D.MolDraw2DCairo(500, 500)  # Cairo backend renders to PNG
    rdMolDraw2D.PrepareAndDrawMolecule(
        d, mol, highlightAtoms=hit_ats, highlightBonds=hit_bonds
    )
    d.FinishDrawing()
    d.WriteDrawingText(os.path.join(output_dir, f"explanation_{mol_id}.png"))

Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 132.38it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 129.30it/s]
Explain graph: 100%|██████████| 1000/1000 [00:08<00:00, 120.32it/s]
Explain graph: 100%|██████████| 1000/1000 [00:08<00:00, 116.29it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 136.69it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 134.53it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 136.73it/s]
Explain graph: 100%|██████████| 1000/1000 [00:08<00:00, 121.92it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 125.12it/s]
Explain graph: 100%|██████████| 1000/1000 [00:08<00:00, 117.89it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 136.62it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 135.39it/s]
Explain graph: 100%|██████████| 1000/1000 [00:08<00:00, 120.31it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07<00:00, 142.23it/s]
Explain graph: 100%|██████████| 1000/1000 [00:07