# Conditional Generation
Minimal conditional generation workflow on arbitrary reference molecules.

Note that this uses the auxiliary `shepherd-score` package for convenience. If you want to use only the packages in this repo, please refer to `RUNME_conditional_generation_MOSESaq.ipynb`.

In [None]:
import os
import numpy as np
from rdkit import Chem
import torch

from shepherd import load_model
from shepherd.inference import generate

from shepherd_score.visualize import draw, draw_molecule, draw_sample
from shepherd_score.container import Molecule
from shepherd_score.conformer_generation import optimize_conformer_with_xtb, charges_from_single_point_conformer_with_xtb
from shepherd_score.conformer_generation import update_mol_coordinates, embed_conformer_from_smiles
from shepherd_score.evaluations.evaluate.pipelines import ConditionalEvalPipeline

# Directory handed to the xTB helpers below as `temp_dir`; uses $TMPDIR when set, else the current directory.
tmp_dir = os.getenv('TMPDIR', './')

### Load ShEPhERD
This will download the model weights if this is the first time using it.

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so that this cell
# does not fail outright on machines without CUDA.
# NOTE(review): assumes `load_model` accepts device='cpu' — confirm for your installed version.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model('mosesaq', device=device)

# Solvent passed to the xTB helpers and to the evaluation pipeline further down.
solvent = 'water' # this should be None for any of the GDB models

### Set-up your reference molecule
You should obtain a conformer that you would like to condition ShEPhERD on. For the sake of example, we just randomly embed a molecule from a SMILES string below.
ShEPhERD was trained on conformers that were locally relaxed with xTB. However, if you want to condition ShEPhERD on an exact conformer without changing its geometry, you can instead *just* compute the partial charges with xTB.

Next, we need to center the molecule. This is important; otherwise, ShEPhERD may fail.

In [3]:
# For the sake of this example, a conformer is randomly embedded from a SMILES string.
ref_smiles = 'Fc1c(N)cc(NC(=O)Cn2c(=O)nc(C)cc(C)2)cc1'
mol = embed_conformer_from_smiles(ref_smiles)

# Optional step: locally optimize the conformer with xtb (this also yields the partial charges).
formal_charge = Chem.GetFormalCharge(mol)
mol, _, partial_charges = optimize_conformer_with_xtb(
    mol,
    solvent=solvent,
    num_cores=1,
    charge=formal_charge,
    temp_dir=tmp_dir,
)

# Alternative: to keep the exact conformer, compute only the partial charges instead.
# partial_charges = charges_from_single_point_conformer_with_xtb(mol, solvent=solvent, num_cores=1, charge=Chem.GetFormalCharge(mol), temp_dir=tmp_dir)

# Required: shift the molecule so that the mean of its atomic coordinates lies at the origin.
ref_positions = mol.GetConformer().GetPositions()
mol_COM = ref_positions.mean(axis=0)
mol = update_mol_coordinates(mol, ref_positions - mol_COM)

#### Extract interaction profiles for reference molecule

In [8]:
# Surface settings are read from the loaded model's dataset parameters.
x3_num_points = model.params['dataset']['x3']['num_points']  # This should be 75
surf_probe_radius = model.params['dataset']['probe_radius']  # This should be 0.6

# Build the Molecule object that extracts the relevant interaction profiles of the reference.
ref_molec = Molecule(
    mol,
    num_surf_points=x3_num_points,
    probe_radius=surf_probe_radius,
    partial_charges=partial_charges,
    pharm_multi_vector=False,  # ShEPhERD was trained where this setting is False
)

### Generate samples with ShEPhERD
Typically, you will want to scan over a range of `N_x1` and `N_x4` values.

In [9]:
# Number of samples generated in one call
batch_size = 10

# Switches selecting which profile components are inpainted
inpaint_flags = dict(
    inpaint_x3_pos=True,
    inpaint_x3_x=True,
    inpaint_x4_pos=True,
    inpaint_x4_direction=True,
    inpaint_x4_type=True,
)

# These are the inpainting targets, all taken from the reference molecule
inpaint_targets = dict(
    center_of_mass=np.zeros(3),
    surface=ref_molec.surf_pos,
    electrostatics=ref_molec.surf_esp,
    pharm_types=ref_molec.pharm_types,
    pharm_pos=ref_molec.pharm_ancs,
    pharm_direction=ref_molec.pharm_vecs,
)

# Atom and pharmacophore counts are copied from the reference molecule here.
generated_samples = generate(
    model_pl=model,
    batch_size=batch_size,
    N_x1=ref_molec.mol.GetNumAtoms(),
    N_x4=len(ref_molec.pharm_types),
    unconditional=False,
    verbose=True,
    num_steps=400,
    **inpaint_flags,
    **inpaint_targets,
)
torch.cuda.empty_cache() # clear GPU memory

100%|█████████████████████████████████████████| 399/399 [01:22<00:00,  4.86it/s]


### Evaluate generated samples

In [11]:
# Number of surface points used for both the evaluation reference and the pipeline
eval_num_surf_points = 400

# Reduce each generated sample to an (atoms, positions) pair taken from its 'x1' entry
generated_samples_ap = [
    (x1['atoms'], x1['positions'])
    for x1 in (sample['x1'] for sample in generated_samples)
]

# Reference molecule rebuilt with the evaluation settings
ref_molec_eval = Molecule(
    mol,
    num_surf_points=eval_num_surf_points,
    probe_radius=1.2,
    partial_charges=partial_charges,
    pharm_multi_vector=False,
)

# Set up the pipeline, run it, then collect the results as pandas objects
cond_eval_pipeline = ConditionalEvalPipeline(
    ref_molec_eval,
    generated_samples_ap,
    condition='all',
    num_surf_points=eval_num_surf_points,
    pharm_multi_vector=False,
    solvent=solvent,
)
cond_eval_pipeline.evaluate(num_workers=4, verbose=True)
cond_evals_out_summary, cond_evals_out_detailed = cond_eval_pipeline.to_pandas()

Conditional Eval: 100%|█████████████████████████| 10/10 [00:49<00:00,  4.91s/it]


In [14]:
# Subset of summary entries to display
summary_metrics = [
    'frac_valid',
    'frac_valid_post_opt',
    'frac_consistent',
    'frac_unique_post_opt',
    'avg_graph_diversity',
]
cond_evals_out_summary[summary_metrics]

frac_valid                   0.9
frac_valid_post_opt          0.9
frac_consistent              0.8
frac_unique_post_opt         1.0
avg_graph_diversity     0.836199
dtype: object

In [13]:
# Show only the first row of the detailed results
cond_evals_out_detailed.head(n=1)

Unnamed: 0,generated_mols,molblocks,molblocks_post_opt,strain_energies,rmsds,SA_scores,logPs,QEDs,fsp3s,SA_scores_post_opt,...,sims_surf_target_relax,sims_esp_target_relax,sims_pharm_target_relax,sims_surf_target_relax_optimal,sims_esp_target_relax_optimal,sims_pharm_target_relax_optimal,sims_surf_target_relax_esp_aligned,sims_pharm_target_relax_esp_aligned,graph_similarities,graph_similarities_post_opt
0,"([6, 7, 6, 1, 1, 1, 6, 1, 6, 6, 7, 6, 9, 6, 6,...",\n RDKit 3D\n\n 36 38 0 0 0 0...,\n RDKit 3D\n\n 36 38 0 0 0 0...,0.028085,0.226549,3.446423,1.53562,0.877774,0.266667,3.446423,...,0.850946,0.704684,0.551546,0.877846,0.722548,0.727065,0.877622,0.617406,0.158416,0.158416


### Visualize these samples in 3D

In [21]:
# Visualize the generated samples (without post-processing)
ind_to_view = 0
draw_sample(generated_samples[ind_to_view], ref_molec.mol, only_atoms=False)

<py3Dmol.view at 0x148c3fb54950>

In [19]:
# Draw the reference first, then add the selected sample after relaxation to the same view.
#  The selected sample must be valid, otherwise this will error
view = draw_molecule(ref_molec, opacity=0.6, no_surface_points=True, opacity_features=0.7)
relaxed_molblock = cond_eval_pipeline.molblocks_post_opt[ind_to_view]
relaxed_mol = Chem.MolFromMolBlock(relaxed_molblock, removeHs=False)
draw(relaxed_mol, view=view)

<py3Dmol.view at 0x148c3fb7d390>