In [7]:
#!/usr/bin/env python
# coding: utf-8

import os

from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.utilities.seed import seed_everything
from argparse import ArgumentParser

from pathlib import Path
import numpy as np

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import copy

from LitCoordAECopy1 import LitCoordAE
from data_utils import CODDataModule

In [19]:
# Restore a LitCoordAE from a saved checkpoint. No hyper-parameter overrides
# are passed here; only strict=False is forwarded to load_from_checkpoint.
# NOTE(review): both `checkpoint_start` and `model` are assigned again by other
# cells of this notebook (the configuration cell and the model-construction
# cell), so which model is live depends on the order the cells were executed —
# confirm which checkpoint is intended.
checkpoint_start='logs/default/version_39/checkpoints/epoch=0-step=45.ckpt'
model = LitCoordAE.load_from_checkpoint(checkpoint_start, strict=False)

In [21]:
# Run configuration. Every name assigned in this cell is read by later cells
# (model construction, data loading, generation).
dataset = 'QM9'
alignment_type = 'kabsch'
seed = 1334
batch_size = 20
val_num_samples = 10
use_X = False
use_R = False
w_reg = 1e-5
refine_mom = 0.99
refine_steps = 0
useFF = False
dim_h = 50
dim_f = 100
mpnn_steps = 5  # used for every dataset except 'QM9', which overrides it below
checkpoint_start = 'logs/default/version_38/checkpoints/epoch=94-step=35624.ckpt'

# The data location can be redirected with the DL4CHEM_DATA_DIR environment
# variable; when it is unset the original hard-coded path is used unchanged.
data_dir = os.environ.get('DL4CHEM_DATA_DIR', '/home/bb596/rds/hpc-work/dl4chem/')
dim_edge = 10

seed_everything(seed)

# Dataset-dependent sizes. Only the QM9 branch changes mpnn_steps; any other
# dataset keeps the value configured above.
if dataset == 'QM9' :
    mpnn_steps = 3
    n_max = 9
    dim_node = 22
else :
    n_max = 50
    dim_node = 35

data_module = CODDataModule(dataset=dataset,
                            data_dir=data_dir,
                            batch_size=batch_size,
                            val_num_sample=val_num_samples)

Global seed set to 1334


In [8]:
# Hyper-parameters handed to LitCoordAE. They are collected once so that the
# "fresh model" and "load from checkpoint" paths cannot drift apart.
model_hparams = dict(n_max=n_max,
                     dim_node=dim_node,
                     dim_edge=dim_edge,
                     dim_h=dim_h,
                     dim_f=dim_f,
                     batch_size=batch_size,
                     mpnn_steps=mpnn_steps,
                     alignment_type=alignment_type,
                     use_X=use_X,
                     use_R=use_R,
                     useFF=useFF,
                     refine_steps=refine_steps,
                     refine_mom=refine_mom,
                     w_reg=w_reg,
                     val_num_samples=val_num_samples)

if checkpoint_start is None :
    # No checkpoint configured: build a new model from the hyper-parameters.
    model = LitCoordAE(**model_hparams)
else :
    # Load weights from the checkpoint, passing the same hyper-parameters as
    # keyword overrides (strict=False is forwarded to load_from_checkpoint).
    model = LitCoordAE.load_from_checkpoint(checkpoint_start, strict=False,
                                            **model_hparams)

In [22]:
# Run the data module's setup step before any dataloader is requested by the
# generation cells below.
# NOTE(review): presumably this builds the train/test splits — confirm in data_utils.
data_module.setup()

In [16]:
def optimizeWithFF(mol):
    """
    Return an MMFF-relaxed copy of a molecule, with hydrogens stripped again.

    Args :
        mol : RDKit molecule carrying the conformer to relax. It is not modified.

    Returns :
        A new RDKit molecule (hydrogens removed) holding the relaxed coordinates.

    Raises :
        ValueError : if the MMFF force field could not be set up for the
            molecule (MMFFOptimizeMolecule returned -1), in which case the
            coordinates would otherwise be returned unrelaxed.
    """

    # AddHs returns a new molecule, so the input is left untouched without an
    # explicit copy. addCoords=True places the new hydrogens in 3D.
    opti_mol = Chem.AddHs(mol, addCoords=True)

    # Return code: 0 = converged, 1 = more iterations needed, -1 = no force field.
    status = AllChem.MMFFOptimizeMolecule(opti_mol)
    if status == -1 :
        raise ValueError("MMFF force field could not be set up for this molecule")

    return Chem.RemoveHs(opti_mol)

def generateETKDGConf(mol) :
    """
    Embed a fresh conformer for a molecule and relax a copy of it with MMFF.

    Args :
        mol : RDKit molecule. It is not modified; its own conformers are ignored.

    Returns :
        (original_mol, opti_mol) : the embedded conformer and its MMFF-relaxed
        counterpart, both with hydrogens removed. Both are None when embedding
        fails or an RDKit call raises; opti_mol alone is None when the MMFF
        force field could not be set up.
    """

    # AddHs returns a new molecule; existing coordinates are dropped so that
    # only the embedded conformer remains.
    embedded_mol = Chem.AddHs(mol)
    embedded_mol.RemoveAllConformers()

    try :
        # EmbedMolecule signals failure with a return value of -1, not an exception.
        if AllChem.EmbedMolecule(embedded_mol) == -1 :
            return None, None

        # Relax while hydrogens are still present, as optimizeWithFF does, and
        # strip them only afterwards.
        relaxed_mol = copy.deepcopy(embedded_mol)
        mmff_status = AllChem.MMFFOptimizeMolecule(relaxed_mol)

        original_mol = Chem.RemoveHs(embedded_mol)
        # -1 means the force field could not be set up: nothing was optimized.
        opti_mol = Chem.RemoveHs(relaxed_mol) if mmff_status != -1 else None
    except Exception :
        # Exception (not a bare except) so that KeyboardInterrupt still stops the run.
        original_mol = None
        opti_mol = None

    return original_mol, opti_mol

def getRMSD(reference_mol, positions):
    """
    Build a conformer from the given coordinates and measure its RMSD to the
    reference molecule, before and after force-field optimization.

    Args :
        reference_mol : RDKit.Molecule
        positions : Tensor(n_atom, 3); only the first n_atom rows are read,
            where n_atom is the atom count of reference_mol

    Returns :
        rmsd_raw : RMSD between the reference and the unoptimized conformer
        rmsd_opti : RMSD between the reference and the optimized conformer,
            or None if the optimization (or its alignment) failed
        test_mol : copy of reference_mol carrying only the new conformer
        test_mol_optimized : optimized version of test_mol, or None on failure
    """

    n_atom = reference_mol.GetNumAtoms()

    test_cf = Chem.rdchem.Conformer(n_atom)
    for k in range(n_atom):
        test_cf.SetAtomPosition(k, positions[k].tolist())

    test_mol = copy.deepcopy(reference_mol)
    # Drop every inherited conformer (whatever its id) so the new one is the
    # only conformer, and therefore the one AlignMol uses by default.
    test_mol.RemoveAllConformers()
    test_mol.AddConformer(test_cf)

    # AlignMol transforms its FIRST argument (the probe) onto the second. The
    # generated molecules are passed as probes so that the caller's
    # reference_mol is never moved; the RMSD value is the same either way.
    try :
        test_mol_optimized = optimizeWithFF(test_mol)
        rmsd_opti = AllChem.AlignMol(test_mol_optimized, reference_mol)
    except Exception :
        # Exception (not a bare except) so that KeyboardInterrupt still stops the run.
        test_mol_optimized = None
        rmsd_opti = None

    rmsd_raw = AllChem.AlignMol(test_mol, reference_mol)

    return rmsd_raw, rmsd_opti, test_mol, test_mol_optimized

In [17]:
def generate_molecules(batch_idx, batch, gen_dir) :
    """
    Run the notebook-level `model` on one batch and write, for every molecule,
    an SDF file with its conformers and a PNG depiction.

    Each SDF holds the reference conformer followed by whichever of these were
    produced: the generated conformer, its force-field-optimized version, an
    ETKDG conformer and its MMFF-optimized version, all aligned together.

    Args :
        batch_idx : index of the batch within the dataloader
        batch : (tensors, mols) where tensors unpacks to
            (nodes, masks, edges, proximity, pos) and mols is the matching
            sequence of RDKit molecules
        gen_dir : output prefix; files are written to f"{gen_dir}{i}.sdf" and
            f"{gen_dir}{i}.png"

    Returns :
        (rmsds_raw, rmsds_opti) : numpy arrays with the raw RMSD of every
        molecule in the batch and the optimized RMSD of those molecules whose
        optimization succeeded.
    """
    print(batch_idx)
    tensors, mols = batch
    nodes, masks, edges, proximity, pos = tensors

    # Only the last of the six model outputs is used here.
    _, _, _, _, _, PX_pred = model(nodes, masks, edges, proximity, pos)

    # Index files from the configured batch size rather than len(mols): a
    # shorter final batch would otherwise reuse indices of earlier batches and
    # overwrite their files. Assumes the dataloader was built with the
    # notebook-level `batch_size`, as passed to CODDataModule.
    first_index = batch_idx * batch_size

    rmsds_raw = []
    rmsds_opti = []
    for mol_idx, (ref_mol, gen_pos) in enumerate(zip(mols, PX_pred)) :
        i = first_index + mol_idx

        rmsd_raw, rmsd_opti, test_mol, test_mol_optimized = getRMSD(ref_mol, gen_pos)
        rmsds_raw.append(rmsd_raw)
        if rmsd_opti is not None :
            rmsds_opti.append(rmsd_opti)

        etkdg_conf, mmff_conf = generateETKDGConf(ref_mol)

        # Gather every available conformer on a copy of the reference.
        final_mol = copy.deepcopy(ref_mol)
        for mol in (test_mol, test_mol_optimized, etkdg_conf, mmff_conf) :
            if mol is not None :
                final_mol.AddConformer(mol.GetConformer(0), assignId=True)

        Chem.rdMolAlign.AlignMolConformers(final_mol)

        w = Chem.SDWriter(f"{gen_dir}{i}.sdf")
        try :
            # Use the actual conformer ids instead of assuming 0..n-1.
            for conf in final_mol.GetConformers() :
                w.write(final_mol, confId=conf.GetId())
        finally :
            # Always release the file handle, even if a write fails.
            w.close()

        AllChem.Compute2DCoords(final_mol)
        Draw.MolToFile(final_mol, f"{gen_dir}{i}.png")

    return np.array(rmsds_raw), np.array(rmsds_opti)

In [12]:
# Generate output files for every batch of the test dataloader.
# generate_molecules builds its file names as f"{gen_dir}{i}...", so the
# trailing slash in gen_dir is required. exist_ok=True lets the cell be re-run
# when the directory already exists.
gen_dir = 'generation_sdf_cod_v2/'
Path(gen_dir).mkdir(exist_ok=True)
for batch_idx, batch in enumerate(data_module.test_dataloader()) :
    generate_molecules(batch_idx, batch, gen_dir)

0




1


KeyboardInterrupt: 

In [23]:
# Same generation loop as for the test split, but over the train dataloader
# and into a separate output directory (trailing slash required, see the file
# names built in generate_molecules).
# NOTE(review): uses whichever `model` the kernel currently holds — confirm it
# is the one the directory name refers to.
gen_dir = 'generation_sdf_cod_train_untrained/'
Path(gen_dir).mkdir(exist_ok=True)
for batch_idx, batch in enumerate(data_module.train_dataloader()) :
    generate_molecules(batch_idx, batch, gen_dir)

0
1
2
3


KeyboardInterrupt: 