In [1]:
import os
import argparse
import pickle
import yaml
import torch
from glob import glob
from tqdm import tqdm
from easydict import EasyDict
from rdkit.Chem import AllChem

%load_ext autoreload
%autoreload 2

# Work from the directory that contains the `models` package imported below.
# The move is guarded so that re-running this cell does not climb one more
# directory each time (a bare `os.chdir('..')` is not idempotent).
# NOTE(review): assumes `models/` lives in the parent of the notebook folder — confirm.
if not os.path.isdir('models'):
    os.chdir('..')

from models.epsnet import *
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Compact tensor printing: two decimals, no scientific notation.
torch.set_printoptions(precision=2, sci_mode=False)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Export CUDA_LAUNCH_BLOCKING=1 into this process's environment.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [48]:
# Build the network from the `model` section of the YAML config and move it to `device`.
config_path = 'configs/DiffAlign.yml'
with open(config_path, 'r') as f:
    parsed_yaml = yaml.safe_load(f)
config = EasyDict(parsed_yaml)
model = get_model(config.model)
model = model.to(device)

In [68]:
# Path (relative to the working directory) of the pickled data used for validation.
# NOTE(review): presumably ConformationDataset unpickles this file — only point it at trusted data.
data_path = 'cross_set_v2/medium.pkl'
# No transform is passed, so nothing is applied to the samples at load time.
val_set = ConformationDataset(data_path, transform=None)

In [69]:
from utils.dataloader import DataLoader
# In-order loader that yields one sample per batch.
val_tmp = DataLoader(val_set, batch_size=1, shuffle=False)
val_it = iter(val_tmp)
# The first batch is drawn and discarded; `batch` holds the second one.
next(val_it)
batch = next(val_it)

In [73]:
# Visualise both structures stored in one sample: the plain entry
# (`smiles`/`pos`) and its `_r` counterpart (`smiles_r`/`pos_r`).
import py3Dmol  # imported here because the notebook-level import sits in a later cell

index = 0
sample = batch[index]

# Rebuild each molecule from its SMILES, add hydrogens, embed conformers so the
# molecule has coordinates to overwrite, then write the stored positions.
# NOTE(review): assumes the stored positions follow RDKit's atom ordering for
# the hydrogen-added SMILES — confirm against the dataset builder.
mol = rdkit.Chem.rdmolops.AddHs(rdkit.Chem.MolFromSmiles(sample['smiles']))
AllChem.EmbedMultipleConfs(mol)
init_mol = set_rdmol_positions(mol, sample['pos'])

mol2 = rdkit.Chem.rdmolops.AddHs(rdkit.Chem.MolFromSmiles(sample['smiles_r']))
AllChem.EmbedMultipleConfs(mol2)
out_mol = set_rdmol_positions(mol2, sample['pos_r'])

# One stick-style viewer per structure, in the same order as before.
for rendered_mol in (init_mol, out_mol):
    mblock = Chem.MolToMolBlock(rendered_mol)
    view = py3Dmol.view(width=500, height=500)
    view.addModel(mblock, 'mol')
    view.setStyle({'stick': {}})
    view.zoomTo()
    view.show()

In [6]:
# Rebind `batch` to the result of `.to(device)`; the inference and viewer cells read this object.
batch = batch.to(device)

In [7]:
import py3Dmol

In [8]:
def show(mol, style='stick', width=500, height=500):
    """Render an RDKit molecule in an inline py3Dmol viewer.

    Args:
        mol: Molecule to draw; it is serialised with ``Chem.MolToMolBlock``.
        style: Key of the style dict handed to ``setStyle``. Defaults to
            ``'stick'``, the value that used to be hard-coded.
        width: Viewer width passed to ``py3Dmol.view`` (default 500).
        height: Viewer height passed to ``py3Dmol.view`` (default 500).

    Returns:
        None. The viewer is displayed as a side effect.
    """
    mblock = Chem.MolToMolBlock(mol)
    view = py3Dmol.view(width=width, height=height)
    view.addModel(mblock, 'mol')
    view.setStyle({style: {}})
    view.zoomTo()
    view.show()

In [58]:
# Load the "non_self_align" weights into the model.
# map_location keeps the load working when the checkpoint was saved on a device
# that is not available here (for example another GPU index).
# NOTE(review): the next cell loads a different checkpoint into the same model;
# run only the one you intend to evaluate. torch.load unpickles the file, so
# only use checkpoints from a trusted source.
model.load_state_dict(torch.load("./param/non_self_align/84.pt", map_location=device))

<All keys matched successfully>

In [12]:
# Load the "self_align_drugs" weights into the model.
# map_location keeps the load working when the checkpoint was saved on a device
# that is not available here (for example another GPU index).
# NOTE(review): this replaces whatever weights the model held before, including
# those loaded by the previous cell. torch.load unpickles the file, so only use
# checkpoints from a trusted source.
model.load_state_dict(torch.load("./param/self_align_drugs/3.pt", map_location=device))

<All keys matched successfully>

In [59]:
# inference
import copy  # made explicit: no earlier cell imports `copy` by name

n_steps = 1000
# Entries carrying the "_r" suffix; each one is moved to its un-suffixed name
# in the template copy of the sample.
keys = ["atom_type_r", "edge_index_r", "edge_type_r", "pos_r", "num_r", "smiles_r", "rdmol_r"]

# Split every sample into two views:
#   batch1 (ligand)   keeps the plain entries and drops the "_r" ones;
#   batch2 (template) has its plain entries replaced by the "_r" values.
batch1 = batch.to_data_list()
batch2 = copy.deepcopy(batch1)
for ligand_data, template_data in zip(batch1, batch2):
    for key in keys:
        plain_key = key[:-2]  # strip the trailing "_r"
        del template_data[plain_key]
        template_data[plain_key] = ligand_data[key]
        del ligand_data[key]
        del template_data[key]
ligand_batch = Batch.from_data_list(batch1)
template_batch = Batch.from_data_list(batch2)
model.eval()

# Overwrite the ligand coordinates in place with standard-normal noise; this is
# the starting point handed to the sampler.
ligand_batch.pos.normal_()

pos_gen, pos_gen_traj = model.langevin_dynamics_sample(
    ligand_batch=ligand_batch,
    template_batch=template_batch,
    # NOTE(review): this line used to be annotated "Done in transforms", but the
    # dataset in this notebook is built with transform=None — confirm the flag.
    extend_order=True,
    n_steps=n_steps,
    step_lr=1e-6,
    w_global=0.9,
    global_start_sigma=1.0,
    clip=1000.0,
    clip_local=None,
    sampling_type='generalized',
    eta=1.0
)
# Bring the final coordinates back to host memory for the viewer cells.
pos_gen = pos_gen.cpu()

sample: 0it [00:00, ?it/s]

sample: 999it [00:21, 45.47it/s]


In [11]:
from rdkit.Chem import AllChem

In [12]:
from utils.chem import *

In [63]:
# Probe: an Embedding with 1000 rows accepts indices 0..999, so index 1000 is
# rejected. The error is caught and printed so that "Run All" is not halted by
# this demonstration cell.
try:
    torch.nn.Embedding(1000, 128)(torch.tensor(1000))
except IndexError as err:
    print(f'IndexError: {err}')

IndexError: index out of range in self

In [60]:
# Show the stored input structure next to the sampled one.
index = 0

# Input: rebuilt from its SMILES, then given the positions stored in the batch.
mol = rdkit.Chem.rdmolops.AddHs(rdkit.Chem.MolFromSmiles(batch[index]['smiles']))
AllChem.EmbedMultipleConfs(mol)
init_mol = set_rdmol_positions(mol, batch[index]['pos'])

# Output: same recipe, but with the sampled coordinates.
# NOTE(review): the whole `pos_gen` tensor is used, which only matches one
# molecule when the batch holds a single sample — confirm before batching.
mol2 = rdkit.Chem.rdmolops.AddHs(rdkit.Chem.MolFromSmiles(ligand_batch[index]['smiles']))
AllChem.EmbedMultipleConfs(mol2)
out_mol = set_rdmol_positions(mol2, pos_gen)

# One stick-style viewer per structure: input first, output second.
for shown_mol in (init_mol, out_mol):
    mblock = Chem.MolToMolBlock(shown_mol)
    view = py3Dmol.view(width=500, height=500)
    view.addModel(mblock, 'mol')
    view.setStyle({'stick': {}})
    view.zoomTo()
    view.show()

In [22]:
# Overlay the stored molecule 0 and a copy carrying the first 19 sampled
# positions in a single viewer.
init_mol = batch.rdmol[0]
out_mol = set_rdmol_positions(init_mol, pos_gen[:19])

view = py3Dmol.view(width=500, height=500)
for overlay_mol in (init_mol, out_mol):
    mblock = Chem.MolToMolBlock(overlay_mol)
    view.addModel(mblock, 'mol')
view.setStyle({'stick': {}})
view.zoomTo()
view.show()

RuntimeError: Bad pickle format: bad endian ID or invalid file format

In [25]:
# Overlay molecule `i` of the batch with its sampled coordinates.
# NOTE(review): the slice assumes every molecule in the batch has 16 atoms — confirm.
i = 11
slice_start = 16 * i
slice_stop = 16 * (i + 1)
init_mol = batch.rdmol[i]
out_mol = set_rdmol_positions(init_mol, pos_gen[slice_start:slice_stop])

view = py3Dmol.view(width=500, height=500)
for overlay_mol in (init_mol, out_mol):
    mblock = Chem.MolToMolBlock(overlay_mol)
    view.addModel(mblock, 'mol')
view.setStyle({'stick': {}})
view.zoomTo()
view.show()

In [9]:
from utils.chem import *
import py3Dmol
# Display molecule `i` before and after sampling, each in its own viewer.
# NOTE(review): the slice assumes every molecule in the batch has 16 atoms — confirm.
i = 1
init_mol = batch.rdmol[i]
sampled_pos = pos_gen[16*i:16*(i+1)]
out_mol = set_rdmol_positions(init_mol, sampled_pos)

for caption, shown_mol in (('Initial molecule', init_mol), ('Generated molecule', out_mol)):
    print(caption)
    show(shown_mol)


Initial molecule


NameError: name 'show' is not defined

In [None]:
# Inspect the sample at index 5 of the batch; the cell output lists its stored fields.
batch[5]

Data(atom_type=[16], boltzmannweight=[1], edge_index=[2, 30], edge_type=[30], idx=[1], num_nodes_per_graph=[1], nx=, pos=[16, 3], rdmol=<rdkit.Chem.rdchem.Mol object at 0x7f005d252da0>, smiles="O=CCC#CCCO", totalenergy=[1])

In [None]:
# Overlay molecule `mol_num` with its sampled coordinates in one viewer.
# NOTE(review): `start` is only right if every earlier molecule also has
# `atom_num` atoms — confirm for the dataset in use.
mol_num = 11
atom_num = 16
start = atom_num * mol_num

init_mol = batch.rdmol[mol_num]
out_mol = set_rdmol_positions(init_mol, pos_gen[start:start+atom_num])

view = py3Dmol.view(width=500, height=500)
for overlay_mol in (init_mol, out_mol):
    mblock = Chem.MolToMolBlock(overlay_mol)
    view.addModel(mblock, 'mol')
view.setStyle({'stick': {}})
view.zoomTo()
view.show()

In [9]:
from utils.chem import *
import py3Dmol

# Display molecule `mol_num` before and after sampling, each in its own viewer.
# NOTE(review): `start` is only right if every earlier molecule also has
# `atom_num` atoms — confirm for the dataset in use.
mol_num = 5
atom_num = 16
start = atom_num * mol_num

init_mol = batch.rdmol[mol_num]
out_mol = set_rdmol_positions(init_mol, pos_gen[start:start+atom_num])

for caption, shown_mol in (('Initial molecule', init_mol), ('Generated molecule', out_mol)):
    print(caption)
    show(shown_mol)

Initial molecule


NameError: name 'show' is not defined

In [None]:
# Overlay the stored molecule 0 and a copy carrying the first 16 sampled
# positions in a single viewer.
init_mol = batch.rdmol[0]
out_mol = set_rdmol_positions(init_mol, pos_gen[:16])

view = py3Dmol.view(width=500, height=500)
for overlay_mol in (init_mol, out_mol):
    mblock = Chem.MolToMolBlock(overlay_mol)
    view.addModel(mblock, 'mol')
view.setStyle({'stick': {}})
view.zoomTo()
view.show()

In [19]:
# print traj
# Show ten evenly spaced snapshots of one molecule along the sampling trajectory.
# The molecule index and the trajectory step are separate variables: the old
# loop reused `i` for both, so every step after 0 sliced far outside the
# molecule's rows and produced an empty position tensor.
mol_idx = 0          # molecule of the batch to follow
atoms_per_mol = 16   # NOTE(review): assumes 16 atoms per molecule, as in the other cells — confirm
atom_start = atoms_per_mol * mol_idx
atom_stop = atom_start + atoms_per_mol

traj_len = len(pos_gen_traj)
for step in np.linspace(0, traj_len-1, 10).astype(int):
    pos = pos_gen_traj[step][atom_start:atom_stop].cpu()
    # `init_mol` comes from an earlier cell and must be the molecule at `mol_idx`.
    out_mol = set_rdmol_positions(init_mol, pos)
    print(f'step {step}')
    show(out_mol)

step 0


step 555


step 1110


step 1666


step 2221


step 2777


step 3332


step 3888


step 4443


step 4999


In [None]:
# print traj
# Ten evenly spaced snapshots along the trajectory, drawn on top of `init_mol`.
# NOTE(review): `[16:]` keeps every row from index 16 onward, not a single
# 16-atom molecule — confirm this is intended when the batch holds more than two.
traj_len = len(pos_gen_traj)
snapshot_steps = np.linspace(0, traj_len-1, 10).astype(int)
for i in snapshot_steps:
    frame = pos_gen_traj[i]
    pos = frame[16:].cpu()
    out_mol = set_rdmol_positions(init_mol, pos)
    print(f'step {i}')
    show(out_mol)

step 0


step 555


step 1110


step 1666


step 2221


step 2777


step 3332


step 3888


step 4443


step 4999


In [None]:
# Write every conformer of `out_mol` to temp/out.sdf.
# The output directory is created when missing, and the writer is always
# closed so the records are flushed to disk even if a write fails.
os.makedirs('temp', exist_ok=True)
writer = Chem.SDWriter('temp/out.sdf')
try:
    # Use each conformer's own id instead of assuming ids run 0..n-1.
    for conf in out_mol.GetConformers():
        writer.write(out_mol, confId=conf.GetId())
finally:
    writer.close()

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
import py3Dmol

def show(smi, style='stick'):
    """Display a molecule in a 200x200 inline py3Dmol viewer.

    Args:
        smi: SMILES string. It is parsed, hydrogens are added, a 3D conformer
            is embedded and then optimised with MMFF (at most 200 iterations).
            A non-string argument is treated as an already-built RDKit
            molecule and drawn as-is, so calls written for the earlier
            ``show(mol)`` helper keep working once this definition replaces it.
        style: Key of the style dict handed to ``setStyle`` (default ``'stick'``).

    Raises:
        ValueError: If the SMILES cannot be parsed or no 3D conformer can be
            embedded for it.
    """
    if isinstance(smi, str):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None.
            raise ValueError(f'Could not parse SMILES: {smi!r}')
        mol = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when it could not produce a conformer;
        # optimising or drawing such a molecule would fail later with a less
        # helpful error.
        if AllChem.EmbedMolecule(mol) == -1:
            raise ValueError(f'Could not embed a 3D conformer for SMILES: {smi!r}')
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    else:
        mol = smi
    mblock = Chem.MolToMolBlock(mol)

    view = py3Dmol.view(width=200, height=200)
    view.addModel(mblock, 'mol')
    view.setStyle({style: {}})
    view.zoomTo()
    view.show()
    
# Example usage: 'CC' is the SMILES string handed to show(); 'P' is an alternative input to try.
show('CC')  # or 'P'