In [1]:
#!/usr/bin/env python
# coding: utf-8

import os

from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.utilities.seed import seed_everything
from argparse import ArgumentParser

from LitCoordAE import LitCoordAE
from data_utils import CODDataModule

In [1]:
parser = ArgumentParser(description='Train network')

dataset= 'QM9'
alignment_type = 'kabsch'
seed=1334
batch_size=20
val_num_samples=10
use_X=False
use_R=False
w_reg=1e-5
refine_mom=0.99
refine_steps=0
useFF=False
dim_h=50
dim_f=100
mpnn_steps=5
checkpoint_start='logs/default/version_36/checkpoints/epoch=87-step=65911.ckpt'

data_dir = '/home/bb596/rds/hpc-work/dl4chem/'
dim_edge = 10

seed_everything(seed)

if dataset == 'QM9' :
    mpnn_steps = 3
    n_max = 9
    dim_node = 22
else :
    mpnn_steps = mpnn_steps
    n_max = 50
    dim_node = 35

data_module = CODDataModule(dataset=dataset, 
                            data_dir=data_dir, 
                            batch_size=batch_size, 
                            val_num_sample=val_num_samples)

if checkpoint_start is None :
    model = LitCoordAE(n_max=n_max,
                       dim_node=dim_node,
                       dim_edge=dim_edge, 
                       dim_h=dim_h, 
                       dim_f=dim_f, 
                       batch_size=batch_size, 
                       mpnn_steps=mpnn_steps, 
                       alignment_type=alignment_type, 
                       use_X=use_X, 
                       use_R=use_R,
                       useFF=useFF,
                       refine_steps=refine_steps, 
                       refine_mom=refine_mom, 
                       w_reg=w_reg,
                       val_num_samples=val_num_samples)

else :
    model = LitCoordAE.load_from_checkpoint(checkpoint_start, strict=False,
                                            n_max=n_max,
                                            dim_node=dim_node, 
                                            dim_edge=dim_edge,
                                            dim_h=dim_h,
                                            dim_f=dim_f,
                                            batch_size=batch_size, 
                                            mpnn_steps=mpnn_steps, 
                                            alignment_type=alignment_type, 
                                            use_X=use_X, 
                                            use_R=use_R,
                                            useFF=useFF,
                                            refine_steps=refine_steps, 
                                            refine_mom=refine_mom, 
                                            w_reg=w_reg,
                                            val_num_samples=val_num_samples)

Global seed set to 1334


In [2]:
from rdkit import Chem
from rdkit.Chem import AllChem
import copy
def getRMSD(reference_mol, positions):
    """
    Args :
        reference_mol : RDKit.Molecule
        positions : Tensor(n_atom, 3)
    """

    def optimizeWithFF(mol):

        mol = Chem.AddHs(mol, addCoords=True)
        AllChem.MMFFOptimizeMolecule(mol)
        mol = Chem.RemoveHs(mol)

        return mol

    n_atom = reference_mol.GetNumAtoms()

    test_cf = Chem.rdchem.Conformer(n_atom)
    for k in range(n_atom):
        test_cf.SetAtomPosition(k, positions[k].tolist())

    test_mol = copy.deepcopy(reference_mol)
    test_mol.RemoveConformer(0)
    test_mol.AddConformer(test_cf)

    useFF = True
    if useFF:
        try:
            rmsd = AllChem.AlignMol(reference_mol, optimizeWithFF(test_mol))
        except:
            rmsd = AllChem.AlignMol(reference_mol, test_mol)
    else:
        rmsd = AllChem.AlignMol(reference_mol, test_mol)

    Chem.rdmolfiles.MolToXYZFile(test_mol, 'generated.xyz')
    test_mol_optimized = optimizeWithFF(test_mol)
    Chem.rdmolfiles.MolToXYZFile(test_mol_optimized, 'generated_optimized.xyz')
    Chem.rdmolfiles.MolToXYZFile(reference_mol, 'reference.xyz')
        
    return rmsd, test_mol, test_mol_optimized

In [3]:
from pathlib import Path
import numpy as np
Path('generation_sdf').mkdir(exist_ok=True)

In [83]:
i = 0
for batch_idx, batch in enumerate(data_module.test_dataloader()) :
    
    tensors, mols = batch
    nodes, masks, edges, proximity, pos = tensors
    
    postZ_mu, postZ_lsgms, priorZ_mu, priorZ_lsgms, X_pred, PX_pred = model(nodes, masks, edges, proximity, pos)
    
    rmsds = []
    for ref_mol, gen_pos in zip(mols, PX_pred) :
        rmsd, test_mol, test_mol_optimized = getRMSD(ref_mol, gen_pos)
        rmsds.append(rmsd)
        w = Chem.SDWriter(f"generation_sdf/{i}.sdf")
        for mol in [ref_mol, test_mol, test_mol_optimized] :
            w.write(mol)
        w.close()

        i += 1
    print(np.array(rmsds).mean())

0.7928741822138056
0.8350640102188198
0.6921377702962437
0.7223237322858075
0.7069256578817102
0.8218796574815466
0.6389606926111497
0.6400068035477975
0.7235435773067662
0.8186656797353447
0.815113600782008
0.6072948877372666
0.9008777025891801
0.747124362807978
0.7969042109665205
0.7948285364541435
0.6657117793981684
0.5920232480892278
0.8640214162547597
0.6449524230612829
0.6002006477656439
0.8430923307214853
0.5639872788872685
0.5354998319158356
0.7060537740215751
0.6589184351826882
0.733029666219813
0.6756117887200812
0.6869158024682358
0.8223530295258668
0.8485074976030363
0.6470974231042618
0.7671595701901415
0.7112414809173606
0.7822185086192424
0.8236946275919991
0.6842986441979442
0.8253521824549604
0.8853862919005803
0.7542792028778635
0.5767185701401213
0.6921716637023978
0.7532442376700521
0.646823416485966
0.7889974381754632
0.8351487224176953
0.7681429094078714
0.7668651897654689
0.7392392744315779
0.6053690473875201
0.8078707599440396
0.7110960937825126
0.73142792180336

RuntimeError: Invariant Violation
	no eligible neighbors for chiral center
	Violation occurred on line 215 in file Code/GraphMol/FileParsers/MolFileStereochem.cpp
	Failed Expression: nbrScores.size()
	RDKIT: 2020.09.1
	BOOST: 1_73
