In [None]:
import open3d 
from shepherd_score_utils.generate_point_cloud import (
    get_atom_coords, 
    get_atomic_vdw_radii, 
    get_molecular_surface,
    get_electrostatics,
    get_electrostatics_given_point_charges,
)
from shepherd_score_utils.pharm_utils.pharmacophore import get_pharmacophores
from shepherd_score_utils.conformer_generation import update_mol_coordinates

print('importing rdkit')
import rdkit
from rdkit.Chem import rdDetermineBonds

import numpy as np
import matplotlib.pyplot as plt

print('importing torch')
import torch
import torch_geometric
from torch_geometric.nn import radius_graph
import torch_scatter

import pickle
from copy import deepcopy
import os
import multiprocessing
from tqdm import tqdm

import sys
# Add the local model directories to the module search path.
# NOTE: list.insert(-1, x) places x *before* the last existing entry of sys.path; it does not append.
# NOTE(review): the later imports of `lightning_module`, `datasets` and `inference` presumably
# resolve from these directories, and their position in sys.path decides which module wins if a
# same-named package is installed elsewhere — confirm before changing the insertion index.
sys.path.insert(-1, "model/")
sys.path.insert(-1, "model/equiformer_v2")

print('importing lightning')
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

from lightning_module import LightningModule
from datasets import HeteroDataset

import importlib

from inference import *

## P(x1,x2), P(x1,x3), P(x1,x4) models trained on ShEPhERD-GDB-17

In [None]:
# pick one
# NOTE: when the notebook is run top to bottom, the P(x1,x3,x4) cell further down assigns `chkpt`
# again and therefore overrides the choice made here; skip that cell to use one of these models.
#chkpt = 'shepherd_chkpts/x1x2_diffusion_gdb17_20240824_submission.ckpt'
#chkpt = 'shepherd_chkpts/x1x3_diffusion_gdb17_20240824_submission.ckpt'
chkpt = 'shepherd_chkpts/x1x4_diffusion_gdb17_20240824_submission.ckpt'

## P(x1,x3,x4) model trained on ShEPhERD-MOSES-aq

In [None]:
# NOTE: this assignment replaces any `chkpt` chosen in the GDB-17 cell above when both cells are run;
# execute only the cell for the model you want to sample from.
chkpt = 'shepherd_chkpts/x1x3x4_diffusion_mosesaq_20240824_submission.ckpt' # checkpoint used for evaluations in preprint
#chkpt = 'shepherd_chkpts/x1x3x4_diffusion_mosesaq_20240824_30epochs_latest.ckpt' # latest checkpoint that was trained for 2-3X longer than the original version in the preprint

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# `map_location` remaps the stored tensors onto `device` while the checkpoint is deserialized,
# so a checkpoint that contains CUDA tensors can still be opened on a CPU-only machine.
# NOTE(review): assumes the project's LightningModule inherits `load_from_checkpoint` from
# pytorch_lightning.LightningModule (which accepts `map_location`) — confirm.
model_pl = LightningModule.load_from_checkpoint(chkpt, map_location=device)

# Parameter dictionary stored on the loaded module; later cells read the noise schedule from it.
params = model_pl.params

model_pl.to(device)
# The wrapped model carries its own `device` attribute, which is set explicitly here.
model_pl.model.device = device

In [None]:
batch_size = 10 # forwarded to inference_sample as `batch_size`
n_atoms = 60 # forwarded to inference_sample as `N_x1`
num_pharmacophores = 10 # forwarded to inference_sample as `N_x4`; set to 5 (just a dummy value) if using a model that does not model x4

In [None]:
# Settings that break symmetry during unconditional generation.
T = params['noise_schedules']['x1']['T'] # T == 400

# Noise is injected at t = 130, 129, ..., 81 (the earlier inline note mentions [150] — meaning unconfirmed),
# with one scale of 1.0 per injection timestep.
inject_noise_at_ts = list(np.arange(130, 80, -1))
inject_noise_scales = [1.0 for _ in inject_noise_at_ts]

harmonize = True
harmonize_ts = [80]
harmonize_jumps = [20]

# To NOT break symmetry (expect spherical molecules with low diversity), use this instead:
#
#     inject_noise_at_ts = []
#     inject_noise_scales = []
#     harmonize = False
#     harmonize_ts = []
#     harmonize_jumps = []

# Draw one batch of unconditional samples with the settings configured above.
generated_samples = inference_sample(
    model_pl,
    batch_size=batch_size,
    N_x1=n_atoms,
    N_x4=num_pharmacophores,
    unconditional=True,
    prior_noise_scale=1.0,
    denoising_noise_scale=1.0,
    # symmetry-breaking controls
    inject_noise_at_ts=inject_noise_at_ts,
    inject_noise_scales=inject_noise_scales,
    harmonize=harmonize,
    harmonize_ts=harmonize_ts,
    harmonize_jumps=harmonize_jumps,
)

In [None]:
# Number of entries returned by inference_sample.
print(len(generated_samples))

In [None]:
print(generated_samples[0]['x1']['atoms']) # atomic numbers of the first generated sample

In [None]:
print(generated_samples[0]['x1']['positions']) # atomic coordinates of the first generated sample

In [None]:
# NOTE(review): presumably only meaningful when the loaded checkpoint models x2 — confirm.
print(generated_samples[0]['x2']['positions']) # shape surface point coordinates of the first generated sample

In [None]:
# NOTE(review): presumably only meaningful when the loaded checkpoint models x3 — confirm.
print(generated_samples[0]['x3']['positions']) # ESP surface point coordinates of the first generated sample

In [None]:
print(generated_samples[0]['x3']['charges']) # ESP values at the x3 surface points of the first generated sample

In [None]:
# NOTE(review): presumably only meaningful when the loaded checkpoint models x4 — confirm.
print(generated_samples[0]['x4']['types']) # pharmacophore types of the first generated sample

In [None]:
print(generated_samples[0]['x4']['positions']) # pharmacophore positions of the first generated sample

In [None]:
print(generated_samples[0]['x4']['directions']) # pharmacophore directions of the first generated sample

In [None]:
# quick visualization of generated samples
# full analyses, including extensive validity checks, can be performed by following https://github.com/coleygroup/shepherd-score
#
# For every sample: write the atoms to an XYZ block, parse it with RDKit, perceive bonds for the
# first net charge that works, run a few sanity checks, and display the resulting molecule.
# Samples that fail any step are reported as "invalid molecule" and skipped.

for b,sample_dict in enumerate(generated_samples):

    x_ = sample_dict['x1']['atoms']
    pos_ = sample_dict['x1']['positions']

    # XYZ block layout: atom count, a title line (1-based sample index), then one line per atom.
    xyz = f'{len(x_)}\n{b+1}\n'
    for a in range(len(x_)):
        atomic_number_ = int(x_[a])
        position_ = pos_[a]

        xyz+= f'{rdkit.Chem.Atom(atomic_number_).GetSymbol()} {str(position_[0].round(3))} {str(position_[1].round(3))} {str(position_[2].round(3))}\n'
    xyz+= '\n'

    try:
        mol_ = rdkit.Chem.MolFromXYZBlock(xyz)
    except Exception as e:
        print(f'invalid molecule: {e}')
        continue
    # RDKit parsers may signal failure by returning None instead of raising, so check explicitly.
    if mol_ is None:
        print('invalid molecule: could not parse XYZ block')
        continue

    # Try each candidate net charge on a fresh copy, because a failed DetermineBonds call may
    # leave the molecule partially modified. The first charge that succeeds wins.
    last_error_ = None
    for c in [0, 1, -1, 2, -2]:
        try:
            mol__ = deepcopy(mol_)
            rdkit.Chem.rdDetermineBonds.DetermineBonds(mol__, charge = c, embedChiral = True)
        except Exception as e:
            last_error_ = e
            continue
        print(c) # net charge for which bond perception succeeded
        break
    else:
        # No charge worked: without this branch the bond-less copy would be carried forward
        # and displayed as if it were a valid molecule.
        print(f'invalid molecule: {last_error_}')
        continue

    mol_ = mol__
    try:
        # Raised explicitly (instead of `assert`) so the check also runs under `python -O`.
        if sum([a.GetNumRadicalElectrons() for a in mol_.GetAtoms()]) != 0:
            raise ValueError('has radical electrons')
        mol_.UpdatePropertyCache()
        rdkit.Chem.GetSymmSSSR(mol_)

    except Exception as e:
        print(f'invalid molecule: {e}')
        continue

    # Round-trip through SMILES so the molecule is drawn as a clean 2D depiction.
    display(rdkit.Chem.MolFromSmiles(rdkit.Chem.MolToSmiles(mol_)))