In [1]:
#!/usr/bin/env python
# coding: utf-8

import os
import torch

from pathlib import Path
import numpy as np

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import copy

from LitCoordAE import LitCoordAE
from data_utils import CODDataModule

In [2]:
# Location of the trained LitCoordAE checkpoint used for generation.
checkpoint_path = "/home/bb596/pytorch_dl4chem-geometry/logs/default/version_43/checkpoints/epoch=193-step=145305.ckpt"

# The checkpoint is opened once by hand only to recover the hyper-parameters
# it stores. NOTE: torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
hparams = checkpoint['hyper_parameters']

# strict=False tolerates key mismatches between the checkpoint state dict
# and the current model definition.
model = LitCoordAE.load_from_checkpoint(checkpoint_path, hparams=hparams, strict=False)

# The model is only used for generation below: switch layers such as
# dropout / batch-norm to their inference behaviour.
model.eval()

In [3]:
# Build the data module for the dataset named in the checkpoint's
# hyper-parameters, then call setup() before any dataloader is requested.
data_module = CODDataModule(dataset=hparams['dataset'])
data_module.setup()

In [4]:
def optimizeWithFF(mol):
    """
    Return an MMFF-optimized copy of a molecule, leaving the input untouched.

    Args :
        mol : RDKit molecule carrying the conformer to relax

    Returns :
        A new RDKit molecule whose explicit hydrogens were added for the
        optimization and removed again afterwards.

    Note :
        The status code returned by MMFFOptimizeMolecule is discarded, so a
        non-converged or skipped optimization is not reported to the caller.
    """
    # Work on a private copy so the caller's molecule is never modified.
    working_copy = copy.deepcopy(mol)

    # Hydrogens are added with coordinates derived from the existing conformer.
    with_hydrogens = Chem.AddHs(working_copy, addCoords=True)
    AllChem.MMFFOptimizeMolecule(with_hydrogens)

    return Chem.RemoveHs(with_hydrogens)

def generateETKDGConf(mol, random_seed=-1):
    """
    Embed a fresh 3D conformer for a molecule and an MMFF-relaxed version of it.

    Args :
        mol : RDKit molecule; it is copied, the input is not modified
        random_seed : seed forwarded to EmbedMolecule. The default (-1) is
            RDKit's own default, i.e. a non-reproducible embedding; pass a
            non-negative integer for reproducible coordinates.

    Returns :
        (embedded_mol, optimized_mol), both without explicit hydrogens, or
        (None, None) when embedding or optimization fails.
    """
    embedded_mol = copy.deepcopy(mol)
    # Start from a molecule without coordinates so nothing from the input
    # geometry leaks into the embedding.
    embedded_mol.RemoveAllConformers()
    embedded_mol = Chem.AddHs(embedded_mol)

    try:
        # EmbedMolecule signals failure with a negative conformer id rather
        # than an exception, so check it explicitly.
        conf_id = AllChem.EmbedMolecule(embedded_mol, randomSeed=random_seed)
        if conf_id < 0:
            return None, None

        # Run the force field while explicit hydrogens are still present
        # (same protocol as optimizeWithFF); strip them only afterwards.
        optimized_mol = copy.deepcopy(embedded_mol)
        AllChem.MMFFOptimizeMolecule(optimized_mol)

        return Chem.RemoveHs(embedded_mol), Chem.RemoveHs(optimized_mol)
    except Exception:
        # Best effort: a molecule RDKit cannot handle yields no conformers.
        # KeyboardInterrupt / SystemExit are deliberately not swallowed.
        return None, None

def getRMSD(reference_mol, positions):
    """
    Compare generated coordinates with a reference conformer.

    Args :
        reference_mol : RDKit.Molecule
        positions : Tensor(n_atom, 3); only the first
            reference_mol.GetNumAtoms() rows are read

    Returns :
        (rmsd_raw, rmsd_opti, test_mol, test_mol_optimized)
        rmsd_raw : RMSD between the reference and the generated coordinates
        rmsd_opti : RMSD between the reference and the force-field-optimized
            generated coordinates, or None if the optimization failed
        test_mol : copy of reference_mol carrying only the generated conformer
        test_mol_optimized : optimized version of test_mol, or None on failure
    """

    n_atom = reference_mol.GetNumAtoms()

    # Build a conformer from the generated coordinates.
    test_cf = Chem.rdchem.Conformer(n_atom)
    for k in range(n_atom):
        test_cf.SetAtomPosition(k, positions[k].tolist())

    # The test molecule must carry the generated conformer only; removing
    # every conformer (not just id 0) guarantees the reference geometry
    # cannot be picked up by mistake.
    test_mol = copy.deepcopy(reference_mol)
    test_mol.RemoveAllConformers()
    test_mol.AddConformer(test_cf)

    # AlignMol moves its first argument (the probe) onto the second one.
    # Align a private copy so the caller's reference molecule keeps its
    # coordinates; the RMSD is unaffected by where the probe ends up.
    probe_mol = copy.deepcopy(reference_mol)

    try:
        test_mol_optimized = optimizeWithFF(test_mol)
        rmsd_opti = AllChem.AlignMol(probe_mol, test_mol_optimized)
    except Exception:
        # Best effort: the optimized comparison is skipped when the force
        # field or the alignment fails. KeyboardInterrupt is not swallowed.
        test_mol_optimized = None
        rmsd_opti = None

    rmsd_raw = AllChem.AlignMol(probe_mol, test_mol)

    return rmsd_raw, rmsd_opti, test_mol, test_mol_optimized

In [6]:
def generate_conformations(batch_idx, batch, gen_dir, start_index=None):
    """
    Generate conformers for one batch and write them to disk.

    For every molecule of the batch, an SDF file and a PNG depiction are
    written. The SDF contains the reference conformer followed by whichever
    of these are available: the model-generated conformer, its
    force-field-optimized version, and the two conformers returned by
    generateETKDGConf. All conformers are aligned before writing.

    Args :
        batch_idx : index of the batch, printed and used to number the files
        batch : object exposing nodes, masks, edges, proximities, positions
            and mols attributes
        gen_dir : prefix prepended as-is to the output file names (it must
            end with a path separator to designate a directory)
        start_index : file number of the first molecule of this batch. When
            None, batch_idx * len(batch.mols) is used, which is only
            collision-free if every batch has the same length; pass a running
            offset when the last batch can be shorter.

    Returns :
        (rmsds_raw, rmsds_opti) as numpy arrays; rmsds_opti only holds the
        molecules whose optimization succeeded.
    """
    print(batch_idx)

    nodes = batch.nodes
    masks = batch.masks
    edges = batch.edges
    proximity = batch.proximities
    pos = batch.positions
    mols = batch.mols

    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        # Only the last output is used below.
        _, _, _, _, _, PX_pred = model(nodes, masks, edges, proximity, pos)

    first_index = batch_idx * len(mols) if start_index is None else start_index

    rmsds_raw = []
    rmsds_opti = []
    for mol_idx, (ref_mol, gen_pos) in enumerate(zip(mols, PX_pred)):
        i = first_index + mol_idx

        rmsd_raw, rmsd_opti, test_mol, test_mol_optimized = getRMSD(ref_mol, gen_pos)
        rmsds_raw.append(rmsd_raw)
        if rmsd_opti is not None:
            rmsds_opti.append(rmsd_opti)

        etkdg_conf, mmff_conf = generateETKDGConf(ref_mol)

        # Collect every available conformer on a copy of the reference.
        final_mol = copy.deepcopy(ref_mol)
        for mol in (test_mol, test_mol_optimized, etkdg_conf, mmff_conf):
            if mol is not None:
                final_mol.AddConformer(mol.GetConformer(), assignId=True)

        Chem.rdMolAlign.AlignMolConformers(final_mol)

        # Open the writer only once everything is ready, and always close it
        # so an error or interrupt cannot leave a dangling file handle.
        w = Chem.SDWriter(f"{gen_dir}{i}.sdf")
        try:
            # Use the real conformer ids instead of assuming they are 0..n-1.
            for conf in final_mol.GetConformers():
                w.write(final_mol, confId=conf.GetId())
        finally:
            w.close()

        # Compute2DCoords is called after the SDF is written, so the 3D
        # conformers are already on disk when the depiction is produced.
        AllChem.Compute2DCoords(final_mol)
        Draw.MolToFile(final_mol, f"{gen_dir}{i}.png")

    return np.array(rmsds_raw), np.array(rmsds_opti)

In [None]:
# Output location for the generated .sdf / .png files. The trailing slash
# matters: generate_conformations builds file names by plain string
# concatenation with this prefix.
gen_dir = 'generation_sdf_qm9/'
Path(gen_dir).mkdir(exist_ok=True)
# Run generation over every batch of the test dataloader. File numbers are
# derived inside generate_conformations from batch_idx and the batch length.
for batch_idx, batch in enumerate(data_module.test_dataloader()) :
    generate_conformations(batch_idx, batch, gen_dir)

0
False
1
False


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/bb596/.conda/envs/rdenv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-100ff30f09fd>", line 4, in <module>
    generate_conformations(batch_idx, batch, gen_dir)
  File "<ipython-input-6-6fd912066ae2>", line 41, in generate_conformations
    w.close()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/bb596/.conda/envs/rdenv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2054, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/bb596/.conda/envs/rdenv/lib/python3.7/site-packages/IPython/core/ultratb.py", line 1101, in g

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/bb596/.conda/envs/rdenv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-100ff30f09fd>", line 4, in <module>
    generate_conformations(batch_idx, batch, gen_dir)
  File "<ipython-input-6-6fd912066ae2>", line 41, in generate_conformations
    w.close()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/bb596/.conda/envs/rdenv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2054, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/bb596/.conda/envs/rdenv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3