In [None]:
import open3d 
from shepherd_score_utils.generate_point_cloud import (
    get_atom_coords, 
    get_atomic_vdw_radii, 
    get_molecular_surface,
    get_electrostatics,
    get_electrostatics_given_point_charges,
)
from shepherd_score_utils.pharm_utils.pharmacophore import get_pharmacophores
from shepherd_score_utils.conformer_generation import update_mol_coordinates

print('importing rdkit')
import rdkit
from rdkit.Chem import rdDetermineBonds

import numpy as np
import matplotlib.pyplot as plt

print('importing torch')
import torch
import torch_geometric
from torch_geometric.nn import radius_graph
import torch_scatter

import pickle
from copy import deepcopy
import os
import multiprocessing
from tqdm import tqdm

import sys
# Make the project's model code importable for the local imports further down in this cell.
# NOTE: list.insert(-1, ...) places the new entry just BEFORE the current last element of
# sys.path (it does not append), so every earlier sys.path entry still takes precedence.
# NOTE(review): presumably `lightning_module`, `datasets`, and `inference` live under these
# directories, and the paths are relative to the notebook's working directory -- confirm.
sys.path.insert(-1, "model/")
sys.path.insert(-1, "model/equiformer_v2")

print('importing lightning')
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

from lightning_module import LightningModule
from datasets import HeteroDataset

import importlib

from inference import *

In [None]:
# Path of the model checkpoint to load.
# Default: the checkpoint used for the evaluations in the preprint.
chkpt = 'shepherd_chkpts/x1x3x4_diffusion_mosesaq_20240824_submission.ckpt' # checkpoint used for evaluations in preprint
# Alternative (uncomment to use): the latest checkpoint, trained for 2-3X longer than the original version in the preprint.
#chkpt = 'shepherd_chkpts/x1x3x4_diffusion_mosesaq_20240824_30epochs_latest.ckpt' # latest checkpoint that was trained for 2-3X longer than the original version in the preprint

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Restore the trained Lightning module from the selected checkpoint and move it to `device`.
model_pl = LightningModule.load_from_checkpoint(chkpt)
model_pl.to(device)
# The wrapped model keeps its own `device` attribute, which is set explicitly here.
model_pl.model.device = device

# Parameter dictionary stored on the module; later cells read dataset settings from it.
params = model_pl.params

# Conditioning on Interaction Profiles of: Natural Products

In [None]:
# Load the natural-product entries; each entry is indexed below as [0] = MolBlock, [1] = partial charges.
# NOTE: unpickling can execute arbitrary code -- only load pickle files from a trusted source.
with open('conformers/np/molblock_charges_NPs.pkl', 'rb') as f:
    molblocks_and_charges = pickle.load(f) # len(molblocks_and_charges) == 3

# Pick the target natural product (valid choices: 0, 1, 2).
index = 0
selected_entry = molblocks_and_charges[index]

# Target natural product (hydrogens are kept) and its xTB partial charges in implicit water.
mol = rdkit.Chem.MolFromMolBlock(selected_entry[0], removeHs=False)
charges = np.array(selected_entry[1])
display(mol)

# Extract the target interaction profiles (ESP and pharmacophores).
# First subtract the mean atom position so the conformer is centered at the origin.
mol_coordinates = np.array(mol.GetConformer().GetPositions())
mol_coordinates = mol_coordinates - mol_coordinates.mean(axis=0)
mol = update_mol_coordinates(mol, mol_coordinates)

# Conditional targets, all computed from the centered conformer.
dataset_params = params['dataset']
centers = mol.GetConformer().GetPositions()
radii = get_atomic_vdw_radii(mol)

# Surface point cloud; the number of points and the probe radius come from the model's dataset settings.
surface = get_molecular_surface(
    centers,
    radii,
    dataset_params['x3']['num_points'],
    probe_radius=dataset_params['probe_radius'],
    num_samples_per_atom=20,
)

# Pharmacophore types, positions, and directions, using the model's x4 dataset settings.
pharm_types, pharm_pos, pharm_direction = get_pharmacophores(
    mol,
    multi_vector=dataset_params['x4']['multivectors'],
    check_access=dataset_params['x4']['check_accessibility'],
)

# Electrostatics evaluated at the surface points from the atomic point charges.
electrostatics = get_electrostatics_given_point_charges(charges, centers, surface)

# Conditioning on Interaction Profiles of: PDB Ligands

In [None]:
# Load the PDB-ligand entries; each entry is indexed below as [0] = MolBlock, [1] = partial charges.
# NOTE: unpickling can execute arbitrary code -- only load pickle files from a trusted source.
with open('conformers/pdb/molblock_charges_pdb_lowestenergy.pkl', 'rb') as f:
    molblocks_and_charges = pickle.load(f)

# Pick the target PDB ligand (valid choices: 0, 1, 2, 3, 4, 5, 6).
index = 6
selected_entry = molblocks_and_charges[index]

# Target PDB ligand (hydrogens are kept) and its xTB partial charges in implicit water.
mol = rdkit.Chem.MolFromMolBlock(selected_entry[0], removeHs=False)
charges = np.array(selected_entry[1])
display(mol)

# Extract the target interaction profiles (ESP and pharmacophores).
# First subtract the mean atom position so the conformer is centered at the origin.
mol_coordinates = np.array(mol.GetConformer().GetPositions())
mol_coordinates = mol_coordinates - mol_coordinates.mean(axis=0)
mol = update_mol_coordinates(mol, mol_coordinates)

# Conditional targets, all computed from the centered conformer.
dataset_params = params['dataset']
centers = mol.GetConformer().GetPositions()
radii = get_atomic_vdw_radii(mol)

# Surface point cloud; the number of points and the probe radius come from the model's dataset settings.
surface = get_molecular_surface(
    centers,
    radii,
    dataset_params['x3']['num_points'],
    probe_radius=dataset_params['probe_radius'],
    num_samples_per_atom=20,
)

# Pharmacophore types, positions, and directions, using the model's x4 dataset settings.
pharm_types, pharm_pos, pharm_direction = get_pharmacophores(
    mol,
    multi_vector=dataset_params['x4']['multivectors'],
    check_access=dataset_params['x4']['check_accessibility'],
)

# Electrostatics evaluated at the surface points from the atomic point charges.
electrostatics = get_electrostatics_given_point_charges(charges, centers, surface)

# Conditioning on Interaction Profiles of: Overlapping Fragments from Fragment Screen

In [None]:
# Load the precomputed fragment-merging features (x2 / x3 / x4 representations).
# NOTE: unpickling can execute arbitrary code -- only load pickle files from a trusted source.
with open('conformers/fragment_merging/fragment_merge_condition.pickle', 'rb') as f:
    fragment_merge_features = pickle.load(f)

# Shift every representation by the mean of the x3 positions, so the x3 points are centered at the origin.
COM = fragment_merge_features['x3']['positions'].mean(0)
for representation in ('x2', 'x3', 'x4'):
    shifted_positions = fragment_merge_features[representation]['positions'] - COM
    fragment_merge_features[representation]['positions'] = shifted_positions

# Conditional targets: independent copies, so later changes to them leave the loaded features untouched.
x3_features = fragment_merge_features['x3']
x4_features = fragment_merge_features['x4']
surface = deepcopy(x3_features['positions'])
electrostatics = deepcopy(x3_features['charges'])
pharm_types = deepcopy(x4_features['types'])
pharm_pos = deepcopy(x4_features['positions'])
pharm_direction = deepcopy(x4_features['directions'])

# Running conditional generation via inpainting

In [None]:
# Sizes used by the sampling call in the next cell.
batch_size = 5  # number of samples to generate
n_atoms = 70  # passed to inference_sample as N_x1

# Passed to inference_sample as N_x4; must equal pharm_pos.shape[0] if inpainting.
num_pharmacophores = len(pharm_types)

In [None]:
# Conditional generation: the interaction profiles prepared above are supplied as inpainting targets.
generated_samples = inference_sample(
    model_pl,
    batch_size=batch_size,

    # sizes of the generated x1 and x4 representations
    N_x1=n_atoms,
    N_x4=num_pharmacophores,

    unconditional=False,

    # noise scales
    prior_noise_scale=1.0,
    denoising_noise_scale=1.0,

    # extra noise injection (none requested here)
    inject_noise_at_ts=[],
    inject_noise_scales=[],

    # harmonization (disabled here)
    harmonize=False,
    harmonize_ts=[],
    harmonize_jumps=[],

    # ---- every option from here on only matters when unconditional is False ----

    # x2 is not inpainted directly; note that it is implicitly modeled via x3
    inpaint_x2_pos=False,

    # x3: inpaint both the positions and the per-point feature
    inpaint_x3_pos=True,
    inpaint_x3_x=True,

    # x4: inpaint positions, directions, and types
    inpaint_x4_pos=True,
    inpaint_x4_direction=True,
    inpaint_x4_type=True,

    # x2 inpainting schedule / added noise
    stop_inpainting_at_time_x2=0.0,
    add_noise_to_inpainted_x2_pos=0.0,

    # x3 inpainting schedule / added noise
    stop_inpainting_at_time_x3=0.0,
    add_noise_to_inpainted_x3_pos=0.0,
    add_noise_to_inpainted_x3_x=0.0,

    # x4 inpainting schedule / added noise
    stop_inpainting_at_time_x4=0.0,
    add_noise_to_inpainted_x4_pos=0.0,
    add_noise_to_inpainted_x4_direction=0.0,
    add_noise_to_inpainted_x4_type=0.0,

    # inpainting targets
    center_of_mass=np.zeros(3),  # center of mass of x1; the targets were already centered to zero above
    surface=surface,
    electrostatics=electrostatics,
    pharm_types=pharm_types,
    pharm_pos=pharm_pos,
    pharm_direction=pharm_direction,
)

In [None]:
# Sanity check: number of samples returned by inference_sample.
len(generated_samples) # == batch_size

In [None]:
# Atoms of the first generated sample; the visualization cell below converts each entry to an int
# and uses it as an atomic number.
generated_samples[0]['x1']['atoms']

In [None]:
# Atom positions of the first generated sample; the visualization cell below reads three
# coordinates (x, y, z) per atom from it.
generated_samples[0]['x1']['positions']

In [None]:
# Quick visualization of generated samples.
# Full analyses, including extensive validity checks, can be performed by following https://github.com/coleygroup/shepherd-score
#
# For every sample: write the x1 atoms/positions as an XYZ block, parse it with RDKit, perceive
# bonds by trying a few total charges, run light sanity checks, and display the molecule.
# Samples that fail any step are reported as 'invalid molecule: ...' and skipped.

for b,sample_dict in enumerate(generated_samples):

    x_ = sample_dict['x1']['atoms']
    pos_ = sample_dict['x1']['positions']

    # XYZ block: atom count, a comment line (the 1-based sample index), then one
    # "symbol x y z" line per atom with coordinates rounded to 3 decimals.
    xyz = f'{len(x_)}\n{b+1}\n'
    for a in range(len(x_)):
        atomic_number_ = int(x_[a])
        position_ = pos_[a]

        xyz+= f'{rdkit.Chem.Atom(atomic_number_).GetSymbol()} {str(position_[0].round(3))} {str(position_[1].round(3))} {str(position_[2].round(3))}\n'
    xyz+= '\n'

    try:
        mol_ = rdkit.Chem.MolFromXYZBlock(xyz)
    except Exception as e:
        mol_ = None
        print(f'invalid molecule: {e}')
        continue
    # The parser can also signal failure by returning None instead of raising.
    if mol_ is None:
        print('invalid molecule: could not parse the XYZ block')
        continue

    # Bond perception needs the total molecular charge, which is unknown for a generated
    # sample, so try the candidate charges in order and keep the first one that succeeds.
    mol__ = None
    for c in [0, 1, -1, 2, -2]:
        try:
            # Work on a copy: a failed attempt may leave the molecule partially modified.
            candidate_ = deepcopy(mol_)
            rdkit.Chem.rdDetermineBonds.DetermineBonds(candidate_, charge = c, embedChiral = True)
        except Exception:
            # This charge does not yield a valid bond assignment; try the next one.
            continue
        mol__ = candidate_
        print(c)
        break

    # If every charge failed, do NOT fall through with a half-processed copy (it would be
    # displayed as if bond perception had succeeded); report the sample as invalid instead.
    if mol__ is None:
        mol_ = None
        print('invalid molecule: could not determine bonds for any trial charge')
        continue

    mol_ = mol__
    try:
        num_radical_electrons_ = sum(atom_.GetNumRadicalElectrons() for atom_ in mol_.GetAtoms())
        if num_radical_electrons_ != 0:
            raise ValueError('has radical electrons')
        mol_.UpdatePropertyCache()
        rdkit.Chem.GetSymmSSSR(mol_)

    except Exception as e:
        mol_ = None
        print(f'invalid molecule: {e}')
        continue

    # Round-trip through SMILES so that a 2D depiction is displayed.
    display(rdkit.Chem.MolFromSmiles(rdkit.Chem.MolToSmiles(mol_)))