In [None]:
import torch
import py3Dmol

from rdkit import Chem
from rdkit.Geometry import Point3D
from torch_geometric.data import Batch
from diffalign.utils.chem import set_rdmol_positions, mol_to_graph_data_obj
from diffalign.models.epsnet.diffusion import DiffAlign

### Loading the pretrained model and example data

In [2]:
# Single place to choose where the model (and, downstream, every input batch) runs.
device = 'cuda:0'
model = DiffAlign().to(device)
# map_location restores the checkpoint tensors directly onto `device`, instead of
# onto whichever device they were saved from (which may not exist on this machine).
model.load_state_dict(torch.load("./param/molecular_alignment.pt", map_location=device))
# Inference mode for layers such as dropout/batch-norm; trailing ';' hides the repr.
model.eval();

In [3]:
# Load the query ligand, the reference ligand, and the protein pocket.
# Hydrogens are removed and RDKit sanitization is skipped for all three.
query_mol = Chem.SDMolSupplier("./example/query.sdf", sanitize=False, removeHs=True)[0]
reference_mol = Chem.SDMolSupplier("./example/reference.sdf", sanitize=False, removeHs=True)[0]
pocket_mol = Chem.MolFromPDBFile("./example/pocket.pdb", sanitize=False, removeHs=True)

# RDKit reports a parse failure by returning None rather than raising; fail
# here with a clear message instead of deep inside featurization/sampling.
for mol_name, mol in (("query", query_mol), ("reference", reference_mol), ("pocket", pocket_mol)):
    assert mol is not None, f"failed to parse {mol_name} molecule"

### Sampling a conformation of the query molecule

In [None]:
# Featurize both ligands as graph data objects.
query_data = mol_to_graph_data_obj(query_mol)
reference_data = mol_to_graph_data_obj(reference_mol)

# Translate the reference so its centroid sits at the origin. The pocket is
# moved by the same vector below, and the outputs are moved back afterwards.
mean_pos = reference_data['pos'].mean(0)
reference_data['pos'] = reference_data['pos'] - mean_pos
# Use the `device` chosen when loading the model (not the implicit default
# CUDA device) so changing `device` moves everything consistently.
query_batch = Batch.from_data_list([query_data]).to(device)
reference_batch = Batch.from_data_list([reference_data]).to(device)
# Overwrite the query coordinates with standard Gaussian noise, presumably the
# starting point of the DDPM sampler. No seed is set, so each run differs.
query_batch.pos.normal_()

# Shift a *copy* of the pocket into the centered frame. `pocket_mol` itself is
# never mutated, so it keeps its original coordinates for visualization, and
# neither an exception during sampling nor a re-run of this cell can leave it
# displaced.
pocket_centered = Chem.Mol(pocket_mol)
pocket_conf = pocket_centered.GetConformer()
shift_vec = Point3D(
    mean_pos[0].item(),
    mean_pos[1].item(),
    mean_pos[2].item(),
)
for atom_idx in range(pocket_centered.GetNumAtoms()):
    pocket_conf.SetAtomPosition(atom_idx, pocket_conf.GetAtomPosition(atom_idx) - shift_vec)

generated_pose, _ = model.DDPM_Sampling_UFF(
    query_batch=query_batch,
    reference_batch=reference_batch,
    query_mols=[query_mol],
    pocket_mols=[pocket_centered],
    # UFF guidance settings forwarded to the sampler; see
    # DiffAlign.DDPM_Sampling_UFF for their exact meaning.
    uff_guidance_scale=0.01,
    uff_inner_steps=8,
)

# Undo the centering so results are expressed in the original file frame.
reference_batch['pos'] = reference_batch['pos'] + mean_pos.to(device)
generated_pose = generated_pose.cpu() + mean_pos.cpu()

In [5]:
# Attach the sampled coordinates (already shifted back into the original frame
# in the sampling cell) to the query molecule's topology.
# NOTE(review): presumably set_rdmol_positions returns a copy of query_mol with
# a new conformer; whether it copies or mutates query_mol is not visible here —
# confirm before reusing query_mol downstream.
generated_mol = set_rdmol_positions(query_mol, generated_pose)

### Visualizing the alignment result

In [None]:
# Overlay the reference ligand (orange), the generated pose (green) and the
# pocket (translucent white) in one interactive 3D viewer.
view = py3Dmol.view(width=500, height=500)

# (serialized structure, py3Dmol format) for each model, in display order.
model_blocks = [
    (Chem.MolToMolBlock(reference_mol), 'mol'),
    (Chem.MolToMolBlock(generated_mol), 'mol'),
    (Chem.MolToPDBBlock(pocket_mol), 'pdb'),
]
# Style for each model; list position == py3Dmol model index (models are
# numbered in the order they are added).
model_styles = [
    {'stick': {'color': 'orange'}},
    {'stick': {'color': 'green'}},
    {'stick': {'color': 'white', 'opacity': 0.5}},
]

# Add every model first, then style each one by its index.
for block, fmt in model_blocks:
    view.addModel(block, fmt)
for model_idx, style in enumerate(model_styles):
    view.setStyle({'model': model_idx}, style)

view.zoomTo()
view.show()