In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys 
sys.path.append(".")
import math
import pickle as pkl 
import numpy as np
from tqdm import tqdm

print("We are now running Python in: ", sys.path)

## Preparation (nets, constants, params and functional utils)

In [None]:
import jax 
import jax.numpy as jnp

from functools import partial
from cybertron.common.config_load import load_config

#### load nets
from train.train import MolEditScoreNet
from cybertron.model.molct_plus import MolCT_Plus
from cybertron.readout import GFNReadout

from train.utils import set_dropout_rate_config
from jax.sharding import PositionalSharding

def _sharding(input, shards):

    n_device = shards.shape[0]
    if isinstance(input, (np.ndarray, jax.Array)):
        _shape = [n_device, ] + [1 for _ in range(input.ndim - 1)]
        return jax.device_put(input, shards.reshape(_shape))
    elif input is None:
        return jax.device_put(input, shards)
    else:
        raise TypeError(f"Invalid input: {input}")

from inference.inference import DPM_3_inference, Langevin_inference, DPM_pp_2S_inference

In [None]:
NDEVICES = 1
NATOMS = 64
SHARDING = True #### you can use multiple devices
if SHARDING:
    NDEVICES = len(jax.devices())
    print("{} DEVICES detected: {}".format(NDEVICES, jax.devices()))

def split_rngs(rng_key, shape):
    """Split ``rng_key`` into a block of keys plus one carry-over key.

    Args:
        rng_key: a JAX PRNG key.
        shape: sequence of ints giving the layout of the returned key block
            (may be empty, in which case a single key is returned).

    Returns:
        ``(keys, next_key)`` where ``keys`` is reshaped to ``shape + (-1,)``
        and ``next_key`` is the last key produced by the split.
    """
    shape = tuple(shape)  # accept lists as well as tuples
    # np.prod(()) is the float 1.0; the split count must be an integer.
    size = int(np.prod(shape))
    rng_keys = jax.random.split(rng_key, size + 1)
    return rng_keys[:-1].reshape(shape + (-1,)), rng_keys[-1]

rng_key = jax.random.PRNGKey(8888) #### set your random seed here
np.random.seed(7777)

In [None]:
##### initialize models (structure diffusion model)
encoder_config = load_config("config/molct_plus.yaml")
gfn_config = load_config("config/gfn.yaml")
gfn_config.settings.n_interactions = 4

modules = {
    "encoder": {"module": MolCT_Plus, 
                "args": {"config": encoder_config}},
    "gfn": {"module": GFNReadout, 
            "args": {"config": gfn_config}}
}

##### load params
load_ckpt_paths = ['./params/ZINC_3m/structure_model/moledit_params_track1.pkl', 
                   './params/ZINC_3m/structure_model/moledit_params_track2.pkl',
                   './params/ZINC_3m/structure_model/moledit_params_track3.pkl'] 
noise_thresholds = [0.35, 1.95]

params = []
for path in load_ckpt_paths:
    with open(path, 'rb') as f: 
        params.append(pkl.load(f))
    
if SHARDING:
    ##### replicate params
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())

# Replace each module class in `modules` by an instance and pre-bind, for every
# loaded checkpoint, an `apply` function that already carries that checkpoint's
# parameters for the module.
for k, v in modules.items():
    # Dropout rate is set to 0.0 on the module config before instantiation.
    modules[k]['args']['config'] = \
        set_dropout_rate_config(modules[k]['args']['config'], 0.0)
    modules[k]["module"] = v["module"](**v["args"])
    modules[k]["callable_fn"] = [] 
    for param in params:
        # NOTE: pop() removes the sub-tree for `k` from the checkpoint dict, so this
        # loop cannot be re-run unless `params` is reloaded first.
        partial_params = {"params": param["params"]['score_net'].pop(k)}
        modules[k]["callable_fn"].append(partial(modules[k]["module"].apply, partial_params))

# One score network per checkpoint path; index k pairs the encoder and gfn
# callables that were bound to the same checkpoint.
moledit_scorenets = [MolEditScoreNet(
        encoder=modules['encoder']['callable_fn'][k],
        gfn=modules['gfn']['callable_fn'][k],
    ) for k in range(len(load_ckpt_paths))]

In [None]:
##### initialize models (constituents model)
from config.transformer_config import transformer_config
from transformer.model import Transformer, TransformerConfig

with open('params/ZINC_3m/constituents_model/constituents_vocab.pkl', 'rb') as f:
    constituent_vocab_list = pkl.load(f)
NCONSTITUENTS = len(constituent_vocab_list) # 38
NRG_TOKENS = 3 # seq_len = 38 + 3
SEQ_LEN = NCONSTITUENTS + NRG_TOKENS

NRG_VOCABS = 11
transformer_config.deterministic = True
transformer_config.dtype = jnp.float32
transformer = Transformer(
    config=TransformerConfig(
            **{
                **transformer_config,
                "vocab_size": NATOMS + NRG_VOCABS + 1, 
                "output_vocab_size": NATOMS + NRG_VOCABS + 1}, )
)

##### load params
with open("params/ZINC_3m/constituents_model/moledit_params.pkl", "rb") as f:
    params = pkl.load(f)
    params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)

if SHARDING:
    ##### replicate params
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())

def top_p_sampling(logits, rng_key, p=0.9):
    """Sample one token index with nucleus (top-p) filtering.

    The logits are sorted in ascending order; entries whose cumulative
    probability (accumulated from the least likely token upwards) is below
    ``1 - p`` receive a ``-1e5`` penalty before the categorical draw.

    Args:
        logits: logit vector for a single position (the caller vmaps this
            function over the leading axes).
        rng_key: JAX PRNG key; it is split once.
        p: probability mass to keep.

    Returns:
        ``(token, new_rng_key)``: the sampled index into the original
        ``logits`` and the unused half of the split key.
    """
    # Least likely tokens first.
    order = jnp.argsort(logits)
    ascending_logits = logits[order]
    cumulative = jnp.cumsum(jax.nn.softmax(ascending_logits))

    # Low-probability tail that falls outside the nucleus.
    tail_penalty = (cumulative < (1-p)).astype(jnp.float32)*(-1e5)

    next_key, draw_key = jax.random.split(rng_key)
    choice = jax.random.categorical(draw_key, ascending_logits+tail_penalty)

    return order[choice], next_key
        
##### prepare functions, jit & vmap
jitted_logits_fn = jax.jit(transformer.apply)
top_p_sampling_fn = jax.vmap(jax.vmap(jax.jit(partial(top_p_sampling, p=0.9))))

In [None]:
from rdkit import Chem

def SMILES_to_constituents(smi):
    """Parse a SMILES string into per-atom features and a bond table.

    Hydrogens are removed with ``Chem.RemoveAllHs`` before features are read.

    Args:
        smi: SMILES string.

    Returns:
        dict with
        - ``"atomic_numbers"``, ``"hydrogen_numbers"``, ``"hybridizations"``:
          ``uint8`` arrays with one entry per heavy atom (hybridization is the
          integer value of RDKit's enum),
        - ``"bonds"``: ``{atom_i: {atom_j: bond_type}}``, stored symmetrically.

    Raises:
        ValueError: if RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here with
        # the offending string instead of an opaque error inside RemoveAllHs.
        raise ValueError("RDKit could not parse SMILES: {!r}".format(smi))
    mol = Chem.RemoveAllHs(mol)
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=np.uint8)
    hydrogen_numbers = np.array([atom.GetTotalNumHs() for atom in mol.GetAtoms()], dtype=np.uint8)
    hybridizations = np.array([atom.GetHybridization() for atom in mol.GetAtoms()], dtype=np.uint8)

    bond_ids = np.array([(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()], dtype=np.uint8)
    bond_types = np.array([int(bond.GetBondType()) for bond in mol.GetBonds()], dtype=np.uint8)

    # Symmetric adjacency table: every bond is recorded under both of its atoms.
    topology = {i: {} for i in range(len(atomic_numbers))}
    for (atom_i, atom_j), bond_type in zip(bond_ids, bond_types):
        topology[atom_i][atom_j] = topology[atom_j][atom_i] = bond_type

    constituents_dict = {
        "atomic_numbers": atomic_numbers,
        "hydrogen_numbers": hydrogen_numbers,
        "hybridizations": hybridizations,
        "bonds": topology,
    }

    return constituents_dict

def RDMol_to_constituents_and_structure(mol):
    """Extract per-atom features, the bond table and coordinates from an RDKit mol.

    Hydrogens are removed with ``Chem.RemoveAllHs`` before anything is read.

    Args:
        mol: RDKit molecule carrying at least one conformer.

    Returns:
        ``(constituents_dict, structure)`` where ``constituents_dict`` has the
        keys ``"atomic_numbers"``, ``"hydrogen_numbers"``, ``"hybridizations"``
        (``uint8`` arrays) and ``"bonds"`` (symmetric ``{i: {j: bond_type}}``),
        and ``structure`` holds the positions of the last conformer.

    Raises:
        ValueError: if the molecule has no conformer.
    """
    mol = Chem.RemoveAllHs(mol)
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=np.uint8)
    hydrogen_numbers = np.array([atom.GetTotalNumHs() for atom in mol.GetAtoms()], dtype=np.uint8)
    hybridizations = np.array([atom.GetHybridization() for atom in mol.GetAtoms()], dtype=np.uint8)

    bond_ids = np.array([(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()], dtype=np.uint8)
    bond_types = np.array([int(bond.GetBondType()) for bond in mol.GetBonds()], dtype=np.uint8)

    # Symmetric adjacency table: every bond is recorded under both of its atoms.
    topology = {i: {} for i in range(len(atomic_numbers))}
    for (atom_i, atom_j), bond_type in zip(bond_ids, bond_types):
        topology[atom_i][atom_j] = topology[atom_j][atom_i] = bond_type

    constituents_dict = {
        "atomic_numbers": atomic_numbers,
        "hydrogen_numbers": hydrogen_numbers,
        "hybridizations": hybridizations,
        "bonds": topology,
    }

    conformers = list(mol.GetConformers())
    if not conformers:
        # Previously this case surfaced as an UnboundLocalError on `structure`.
        raise ValueError("molecule has no conformer; cannot extract a 3D structure")
    # When several conformers exist the last one is used, as before.
    structure = np.array(conformers[-1].GetPositions())

    return constituents_dict, structure

def mol_with_index(mol):
    """Label every atom of ``mol`` with its own index and return ``mol``.

    The index is written, as a string, to the ``'molAtomMapNumber'`` property
    of each atom. The molecule is modified in place.
    """
    for position in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(position)
        atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))

    return mol

## De-novo Design

In [None]:
NSAMPLE_PER_DEVICE = 8 # 128
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES

### Constituents Sampling

#### Constituents Sampled from Dataset

In [None]:
#### Sample atom constituents from datasets (ZINC-300k)
#### or alternatively, you can sample from ZINC-3m, ZINC-30m
with open('moledit_dataset/constituents/constituents_ZINC_300k.pkl', 'rb') as f:
    constituents_data = pkl.load(f)

constituents = [
    constituents_data[k] for k in 
        np.random.choice(list(constituents_data.keys()), NSAMPLE_PER_DEVICE * NDEVICES)
]

#### Constituents Sampled from Constituents Model

In [None]:
input_dict = {
    "inputs": jnp.ones((NSAMPLES, SEQ_LEN), dtype=jnp.int32), 
    "generation_result": jnp.ones((NSAMPLES, SEQ_LEN), dtype=jnp.int32)
}
input_dict["inputs"] = input_dict["inputs"].at[:, NCONSTITUENTS:].set(NATOMS + NRG_VOCABS) #### unk token for rg

rng_keys = jax.random.split(rng_key, NSAMPLES*SEQ_LEN + 1)
rng_keys, rng_key = rng_keys[:NSAMPLES*SEQ_LEN].reshape(NSAMPLES, SEQ_LEN, -1), rng_keys[-1]

if SHARDING:
    #### shard inputs 
    ds_sharding = partial(_sharding, shards=global_sharding)
    # Use jax.tree_util.tree_map (the spelling already used when the transformer
    # params are loaded); the top-level jax.tree_map alias is deprecated.
    input_dict = jax.tree_util.tree_map(ds_sharding, input_dict)
    rng_keys = ds_sharding(rng_keys)

inv_temperature = 1.25
for step in tqdm(range(SEQ_LEN)):
    logits = jitted_logits_fn(params, 
                              input_dict['inputs'],
                              input_dict['generation_result'])
    if step >= NCONSTITUENTS:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
    else:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., 1:-NRG_VOCABS].set(1)
    logits += (-1e5) * (1.0 - valid_logits_mask)
    sampled_token, rng_keys = top_p_sampling_fn(logits * inv_temperature, rng_keys)
    input_dict['generation_result'] = input_dict['generation_result'].at[..., step].set(sampled_token[..., step])

# Shift every sampled token down by one; after the shift the first NCONSTITUENTS
# entries of a sequence are used directly as atom counts in the loop below.
generation_result = np.array(input_dict['generation_result']) - 1
constituents = []
for seq in tqdm(generation_result):
    atomic_numbers, hydrogen_numbers, hybridizations = [], [], []
    n_atoms = 0
    #### decode constituents
    # Each vocabulary entry is a string "<atomic number>_<hydrogen count>_<hybridization>";
    # `num` is how many atoms of that kind the sequence asks for.
    for token, num in zip(constituent_vocab_list, seq[:NCONSTITUENTS]):
        atomic_number, hydrogen_number, hybridization = tuple([int(x) for x in token.split('_')])
        atomic_numbers += [atomic_number,] * num 
        hydrogen_numbers += [hydrogen_number,] * num 
        hybridizations += [hybridization,] * num
        n_atoms += num 
        
    #### decode rg
    # The last NRG_TOKENS entries encode the radius of gyration; removing the
    # NATOMS offset leaves the raw symbols.
    rg_seq = seq[-NRG_TOKENS:] - NATOMS
    # print(rg_seq)
    # rg = exp(first symbol) * the decimal number "<second symbol>.<remaining symbols>"
    rg = np.exp(rg_seq[0]) * float("{}.{}".format(rg_seq[1], "".join([str(x) for x in rg_seq[2:]])))
    constituents.append(
        {"atomic_numbers": np.array(atomic_numbers, dtype=np.uint8), 
         "hydrogen_numbers": np.array(hydrogen_numbers, dtype=np.uint8),
         "hybridizations": np.array(hybridizations, dtype=np.uint8), 
         "radius_of_gyrations": np.array([rg], dtype=np.float32)}
    )

### Structure Sampling

In [None]:
NATOMS = 64 
INFERENCE_METHOD = "DPM_3"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg):
    """Evaluate every score network and keep the one whose noise band contains ``sigma``.

    ``noise_thresholds`` splits the noise level into ``len(noise_thresholds) + 1``
    bands; band ``i`` is paired with ``moledit_scorenets[i]``. All networks are
    evaluated, and the band indicators (0/1 floats) select the output to return.

    Returns:
        The gated sum over networks of the last element of each network's output.
    """
    interior_bands = zip(noise_thresholds[:-1], noise_thresholds[1:])
    gates = [sigma < noise_thresholds[0]]
    gates += [jnp.logical_and(sigma >= low, sigma < high) for low, high in interior_bands]
    gates += [sigma >= noise_thresholds[-1]]
    gates = jnp.array(gates, dtype=jnp.float32)

    scores = jnp.array(
        [net.apply({}, atom_feat, bond_feat, x, atom_mask, sigma, rg)[-1]
         for net in moledit_scorenets],
        jnp.float32)

    # Two trailing axes are appended to the gates so they broadcast over each score.
    return jnp.sum(gates[..., None, None] * scores, axis=0)

# Batched, compiled score function handed to the samplers.
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    # Same option the later sections of this notebook offer.
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    # Without this branch a typo would leave `inference_fn` undefined, or silently
    # reuse one left over from a previously executed cell.
    raise ValueError("Unknown INFERENCE_METHOD: {!r}".format(INFERENCE_METHOD))

In [None]:
print("Example constituents: ")
print("\tatomic numbers: ", constituents[0]['atomic_numbers'])
print("\thydrogen numbers: ", constituents[0]['hydrogen_numbers'])
print("\thybridizaions: ", constituents[0]['hybridizations'])
print("\t**REMARK**: hybridization symbols are same with RDkit")

from inference.utils import preprocess_data

print("Preprocessing inputs")
input_dicts = [preprocess_data(c, NATOMS) for c in tqdm(constituents)]
input_dict = {
    k: np.stack([d[k] for d in input_dicts]) for k in input_dicts[0].keys()
}

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
input_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), input_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key)

#### save results 
# Create the output directory first so a finished sampling run is not lost to a missing folder.
os.makedirs('results/de_novo_design', exist_ok=True)
with open('results/de_novo_design/de_novo_design.pkl', 'wb') as f: 
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'constituents': constituents, 'trajectories': trajectories, 'structures': structures}), f)

### Graph Assembly

In [None]:
import Xponge
from graph_assembler.graph_assembler import assemble_mol_graph

success_or_not = []
smileses = []
for i, (atomic_numbers, hydrogen_numbers, structure) in tqdm(enumerate(zip([c['atomic_numbers'] for c in constituents],
                                                                           [c['hydrogen_numbers'] for c in constituents],
                                                                           structures))):
    success, Xponge_mol, smiles = assemble_mol_graph(atomic_numbers, hydrogen_numbers, structure)
    success_or_not.append(success) 
    smileses.append("" if not success else smiles)

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open('results/de_novo_design/de_novo_design.pkl', 'rb') as f: 
    results = pkl.load(f)
    constituents = results['constituents']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 0
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
n_atoms = len(atomic_numbers)
# Drop the padded atom slots: only the first n_atoms rows belong to this molecule.
trajectory = np.array(trajectories)[:, mol_id, :n_atoms, :]
structure = np.array(structures)[mol_id, :n_atoms, :]
# Radius of gyration is measured about the centroid; subtracting it makes the value
# independent of where the structure sits in space (no change if already centred).
centered = structure - np.mean(structure, axis=0, keepdims=True)
rg = np.sqrt(np.sum(centered ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f}/{:.2f} ang".format(n_atoms, constituents[mol_id]['radius_of_gyrations'][0], rg))
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}H{}".format(elements[n], hydrogen_numbers[i]) for i, n in enumerate(atomic_numbers)])
# mol.add_TopologyAttr('names', ["{}".format(elements[n]) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Structure Rendering & Sampling of Textual Representations (SMILES)

In [None]:
NSAMPLE_PER_DEVICE = 16 # 1 / 8
NSTRUCTURE_PER_SAMPLE = 8 # 1024
NSAMPLES = 128 # 1024
NSAMPLE_PER_BATCH = int(NSAMPLE_PER_DEVICE * NDEVICES)
# Later cells reshape the batched results to (NSAMPLES, NSTRUCTURE_PER_SAMPLE, ...),
# which only works when whole batches tile NSAMPLES exactly. Fail here with a clear
# message instead of at a reshape/concatenate several cells later.
if NSAMPLE_PER_BATCH <= 0 or NSAMPLES % NSAMPLE_PER_BATCH != 0:
    raise ValueError(
        "NSAMPLES ({}) must be a positive multiple of NSAMPLE_PER_DEVICE * NDEVICES ({})".format(
            NSAMPLES, NSAMPLE_PER_BATCH))
NSTRUCTURE_PER_BATCH = NSTRUCTURE_PER_SAMPLE * NSAMPLE_PER_BATCH
NBATCHES = NSAMPLES // NSAMPLE_PER_BATCH
NSTRUCTURES = NSAMPLES * NSTRUCTURE_PER_SAMPLE
INFERENCE_METHOD = "DPM_3"

print("NSAMPLE_PER_DEVICE: {}".format(NSAMPLE_PER_DEVICE))
print("NSAMPLES: {}".format(NSAMPLES))
print("NSAMPLE_PER_BATCH: {}".format(NSAMPLE_PER_BATCH))
print("NBATCHES: {}".format(NBATCHES))

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Noise-band-gated score with blending of bond-conditioned and bond-free outputs.

    Every network in ``moledit_scorenets`` is evaluated twice: once with
    ``bond_feat`` and once with ``bond_feat`` replaced by zeros. The two results
    are mixed as ``gamma * with_bonds + (1 - gamma) * without_bonds``, and the
    network whose noise band (defined by ``noise_thresholds``) contains
    ``sigma`` is kept.

    Returns:
        The gated sum over networks of the blended outputs.
    """
    interior_bands = zip(noise_thresholds[:-1], noise_thresholds[1:])
    gates = [sigma < noise_thresholds[0]]
    gates += [jnp.logical_and(sigma >= low, sigma < high) for low, high in interior_bands]
    gates += [sigma >= noise_thresholds[-1]]
    gates = jnp.array(gates, dtype=jnp.float32)[..., None, None]

    def stacked_scores(bonds):
        # Last element of each network's output, stacked along a new leading axis.
        return jnp.array(
            [net.apply({}, atom_feat, bonds, x, atom_mask, sigma, rg)[-1]
             for net in moledit_scorenets],
            jnp.float32)

    with_bonds = stacked_scores(bond_feat)
    without_bonds = stacked_scores(jnp.zeros_like(bond_feat))
    blended = gamma * with_bonds + (1.0 - gamma) * without_bonds

    return jnp.sum(gates * blended, axis=0)

# Batched, compiled score function handed to the samplers.
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    # Without this branch a typo would silently keep the `inference_fn` built by an
    # earlier section (and therefore the earlier score function).
    raise ValueError("Unknown INFERENCE_METHOD: {!r}".format(INFERENCE_METHOD))

### Test SMILES

In [None]:
with open('./moledit_dataset/smileses/rdkit_failed_smileses.pkl', 'rb') as f:
    test_smileses = [str(x).strip() for x in pkl.load(f)[:NSAMPLES]]

### Constituents Sampling

In [None]:
constituents_dicts = {}
constituents_arrs = []

### preprocess smiles
for i, smi in tqdm(enumerate(test_smileses)):
    constituents_dict = SMILES_to_constituents(smi)
    constituents_dicts[smi] = constituents_dict

    #### sample rg 
    constituents_str = np.array(["{}_{}_{}".format(i,j,k) for i,j,k in zip(constituents_dict['atomic_numbers'], 
                                                                           constituents_dict['hydrogen_numbers'], 
                                                                           constituents_dict['hybridizations'])])
    constituents_arrs.append(np.array([np.sum(constituents_str == v) for v in constituent_vocab_list]))
constituents_arrs = jnp.array(constituents_arrs, dtype=jnp.int32)

### sample rgs
input_dict = {
    "inputs": jnp.ones((NSTRUCTURES, SEQ_LEN), dtype=jnp.int32), 
    "generation_result": jnp.ones((NSTRUCTURES, SEQ_LEN), dtype=jnp.int32)
}
input_dict["generation_result"] = input_dict["generation_result"].at[:, :NCONSTITUENTS].set(
                                            jnp.repeat(constituents_arrs, NSTRUCTURE_PER_SAMPLE, axis=0) + 1)
input_dict["inputs"] = input_dict["inputs"].at[:, NCONSTITUENTS:].set(NATOMS + NRG_VOCABS) #### unk token for rg

inv_temperature = 1.25
generation_results = []
rng_keys, rng_key = split_rngs(rng_key, (NBATCHES, ))
# Sample the rg tokens batch by batch; the constituent tokens are already filled in.
for b in tqdm(range(NBATCHES)):
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
    input_dict_ = jax.tree_util.tree_map(lambda x:x[b*NSTRUCTURE_PER_BATCH:(b+1)*NSTRUCTURE_PER_BATCH], input_dict)
    rng_key_b = rng_keys[b]
    # One key per (structure, sequence position), matching the doubly-vmapped sampler.
    rng_keys_b, _ = split_rngs(rng_key_b, (NSTRUCTURE_PER_BATCH, SEQ_LEN))
    
    if SHARDING:
        #### shard inputs 
        ds_sharding = partial(_sharding, shards=global_sharding)
        input_dict_ = jax.tree_util.tree_map(ds_sharding, input_dict_)
        rng_keys_b = ds_sharding(rng_keys_b)

    
    for step in range(NCONSTITUENTS, SEQ_LEN):
        logits = jitted_logits_fn(params, 
                                  input_dict_['inputs'],
                                  input_dict_['generation_result'])
        # Only the last NRG_VOCABS vocabulary entries, minus the final one, may be sampled.
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
        logits += (-1e5) * (1.0 - valid_logits_mask)
        sampled_token, rng_keys_b = top_p_sampling_fn(logits * inv_temperature, rng_keys_b)
        # Tokens are sampled for every position, but only the current step is committed.
        input_dict_['generation_result'] = \
            input_dict_['generation_result'].at[..., step].set(sampled_token[..., step])
    generation_results.append(input_dict_['generation_result'])
    
generation_result = np.array(jnp.concatenate(generation_results, axis=0)) - 1
for i, (smi, seqs) in tqdm(enumerate(zip(test_smileses, generation_result.reshape(NSAMPLES,NSTRUCTURE_PER_SAMPLE,-1)))):    
    #### decode rg
    rg_seqs = seqs[:, -NRG_TOKENS:] - NATOMS
    rgs = []
    for rg_seq in rg_seqs:
        rg = np.exp(rg_seq[0]) * float("{}.{}".format(rg_seq[1], "".join([str(x) for x in rg_seq[2:]])))
        rgs.append(rg)
    constituents_dicts[smi].update({"radius_of_gyrations": np.array(rgs, dtype=np.float32)})

### Structure Sampling

In [None]:
print("Example constituents: ")
print("\tatomic numbers: ", constituents_dicts[test_smileses[0]]['atomic_numbers'])
print("\thydrogen numbers: ", constituents_dicts[test_smileses[0]]['hydrogen_numbers'])
print("\thybridizaions: ", constituents_dicts[test_smileses[0]]['hybridizations'])
print("\tradius of gyrations: ", constituents_dicts[test_smileses[0]]['radius_of_gyrations'])
print("\tbonds: ", constituents_dicts[test_smileses[0]]['bonds'])
print("\t**REMARK**: hybridization symbols are same with RDkit")

from inference.utils import preprocess_data

print("Preprocessing inputs")
input_dicts = []
for smi in tqdm(test_smileses):
    d = constituents_dicts[smi]
    input_dicts.extend([preprocess_data({**{k: v for k, v in d.items() if k != 'radius_of_gyrations'}, 
                                         "radius_of_gyrations": [d['radius_of_gyrations'][i]]}, NATOMS) for i in range(NSTRUCTURE_PER_SAMPLE)])
input_dict = {
    k: np.stack([d[k] for d in input_dicts]) for k in input_dicts[0].keys()
}

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

#### With DPM Solver

In [None]:
rng_keys, rng_key = split_rngs(rng_key, (NBATCHES,))
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
input_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), input_dict)
structures, trajectories = [], []

#### JAX compiles a jitted function when you call it first time.
### so it will be slow when you run this block first time.
print("inference requires {} batches running on {} GPUs".format(NBATCHES, NDEVICES))
for b in range(NBATCHES):
    # Rows [b * NSTRUCTURE_PER_BATCH, (b + 1) * NSTRUCTURE_PER_BATCH) of every input array.
    batch_inputs = jax.tree_util.tree_map(
        lambda x:x[b*NSTRUCTURE_PER_BATCH:(b+1)*NSTRUCTURE_PER_BATCH], input_dict)
    structures_, trajectories_, _ = inference_fn(batch_inputs, rng_keys[b])
    structures.append(structures_)
    # trajectories.append(trajectories_) # if you want to save trajectories

In [None]:
# Merge the batch axis back into (sample, structure-per-sample).
structures = np.array(structures).reshape(NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3)
trajectories = np.array(trajectories).transpose((1, 0, 2, 3, 4)).reshape(-1, NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3) if len(trajectories) > 0 \
                else np.array(trajectories)
# structures, trajectories = [], []

#### save results 
# Create the output directory first so a finished sampling run is not lost to a missing folder.
os.makedirs('results/structure_rendering', exist_ok=True)
with open('results/structure_rendering/result_rdkit_failed_smiles.pkl', 'wb') as f: 
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
    pkl.dump(jax.tree_util.tree_map(np.array, 
                          {'smiles': test_smileses,
                           'constituents': [{k: constituents_dicts[smi][k] for k in constituents_dicts[smi].keys() if k != 'bonds'} 
                                            for smi in  test_smileses],
                           'conditional_infos': [{'bonds': constituents_dicts[smi]['bonds']} for smi in test_smileses],  
                           'trajectories': trajectories, 
                           'structures': structures}), f)

#### With FPS Solver

In [None]:
import numpy as np
import Xponge
from Xponge.helper import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Lipinski import RotatableBondSmarts
from rdkit.Geometry import Point3D
from rdkit.Chem import rdMolTransforms

def get_rotable_dihedrals(mol):
    """List every dihedral ``(i, j, k, l)`` whose central bond ``j-k`` is rotatable.

    Rotatable bonds are the matches of RDKit's ``RotatableBondSmarts``. For each
    match ``(j, k)``, every neighbour ``i`` of ``j`` (other than ``k``) is paired
    with every neighbour ``l`` of ``k`` (other than ``j``).

    Args:
        mol: RDKit molecule.

    Returns:
        ``int32`` array of shape ``(n_dihedrals, 4)``; shape ``(0, 4)`` when the
        molecule has no such dihedral.
    """
    # Neighbour lists in bond-iteration order, kept as plain Python ints so atom
    # indices above 255 are not truncated.
    neighbors = {i: [] for i in range(len(mol.GetAtoms()))}
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        neighbors[begin].append(end)
        neighbors[end].append(begin)

    rotable_bonds = mol.GetSubstructMatches(RotatableBondSmarts)
    rotable_bonds_dihedral_ids = [
        [atom_i, atom_j, atom_k, atom_l]
        for (atom_j, atom_k) in rotable_bonds
        for atom_i in neighbors[atom_j] if atom_i != atom_k
        for atom_l in neighbors[atom_k] if atom_l != atom_j
    ]
    # reshape keeps the result two-dimensional even when the list is empty.
    return np.array(rotable_bonds_dihedral_ids, dtype=np.int32).reshape(-1, 4)

# Dihedral atom indices are computed for the FIRST test SMILES only.
# NOTE(review): the FPS inference cell passes this single array for every batch, so
# molecules other than test_smileses[0] would receive indices that do not belong to
# them. Confirm this is intended (e.g. only one molecule is rendered in this mode).
mol = Chem.MolFromSmiles(test_smileses[0])
rotable_bonds_dihedral_ids = get_rotable_dihedrals(mol)
rotable_bonds_dihedral_ids = jnp.array(rotable_bonds_dihedral_ids)

In [None]:
fp_coeff = 2e3  ### This coeficient controls the additional Fokker-Planck-gradient term in SDE
rng_keys, rng_key = split_rngs(rng_key, (NBATCHES,))
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
input_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), input_dict)
structures, trajectories = [], []

#### JAX compiles a jitted function when you call it first time.
### so it will be slow when you run this block first time.
print("inference requires {} batches running on {} GPUs".format(NBATCHES, NDEVICES))
for b in range(NBATCHES):
    # Rows [b * NSTRUCTURE_PER_BATCH, (b + 1) * NSTRUCTURE_PER_BATCH) of every input array.
    batch_inputs = jax.tree_util.tree_map(
        lambda x:x[b*NSTRUCTURE_PER_BATCH:(b+1)*NSTRUCTURE_PER_BATCH], input_dict)
    structures_, trajectories_, _ = inference_fn(
        batch_inputs, rng_keys[b], 
        dihedral_dict={'dihedral_atom_ids': rotable_bonds_dihedral_ids, 'fp_coeff': fp_coeff})
    structures.append(structures_)
    # trajectories.append(trajectories_) # if you want to save trajectories

In [None]:
# Merge the batch axis back into (sample, structure-per-sample).
structures = np.array(structures).reshape(NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3)
trajectories = np.array(trajectories).transpose((1, 0, 2, 3, 4)).reshape(-1, NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3) if len(trajectories) > 0 \
                else np.array(trajectories)
# structures, trajectories = [], []

#### save results 
# Create the output directory first so a finished sampling run is not lost to a missing folder.
os.makedirs('results/structure_rendering', exist_ok=True)
with open('results/structure_rendering/result_rdkit_failed_smiles_fps.pkl', 'wb') as f: 
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
    pkl.dump(jax.tree_util.tree_map(np.array, 
                          {'smiles': test_smileses,
                           'constituents': [{k: constituents_dicts[smi][k] for k in constituents_dicts[smi].keys() if k != 'bonds'} 
                                            for smi in  test_smileses],
                           'conditional_infos': [{'bonds': constituents_dicts[smi]['bonds']} for smi in test_smileses],  
                           'trajectories': trajectories, 
                           'structures': structures}), f)

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
# with open(f'results/structure_rendering/result_rdkit_failed_smiles.pkl', 'rb') as f:
with open(f'results/structure_rendering/result_rdkit_failed_smiles_fps.pkl', 'rb') as f:
    results = pkl.load(f)
    test_smileses = results['smiles']
    constituents = results['constituents']
    conditional_infos = results['conditional_infos']
    # trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id, structure_id = 0, 0
print("smiles: ", test_smileses[mol_id], len(test_smileses))
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
hybridizations = constituents[mol_id]['hybridizations']
print("atomic numbers:", ",".join([str(x) for x in atomic_numbers]))
print("hydrogen_numbers:", ",".join([str(x) for x in hydrogen_numbers]))
print("hybridizations:", ",".join([str(x) for x in hybridizations]))
n_atoms = len(atomic_numbers)
# trajectory = np.array(trajectories)[:, mol_id, structure_id, :n_atoms, :]
# Drop the padded atom slots: only the first n_atoms rows belong to this molecule.
structure = np.array(structures)[mol_id, structure_id, :n_atoms, :]
# Radius of gyration is measured about the centroid; subtracting it makes the value
# independent of where the structure sits in space (no change if already centred).
centered = structure - np.mean(structure, axis=0, keepdims=True)
rg = np.sqrt(np.sum(centered ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f}/{:.2f} ang".format(n_atoms, rg, constituents[mol_id]['radius_of_gyrations'][structure_id]))
print("Conditional infos indicate the following bonds")
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[n], i) for i, n in enumerate(atomic_numbers)])
# mol.add_TopologyAttr('names', ["{}".format(elements[n]) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Structure Editing

### Chair/Twist-boat

In [None]:
NSAMPLE_PER_DEVICE = 1
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES
NATOMS = 64 
INFERENCE_METHOD = "DPM_3"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Score used for structure editing: band-gated mix of two network evaluations.

    For each network in ``moledit_scorenets`` the output with the given
    ``bond_feat`` and the output with an all-zero ``bond_feat`` are combined as
    ``gamma * with_bonds + (1 - gamma) * zero_bonds``. ``noise_thresholds``
    defines one noise band per network; only the band containing ``sigma``
    contributes to the returned sum.
    """
    # 0/1 indicator per band: below the first threshold, between consecutive
    # thresholds, and at or above the last threshold.
    band_indicators = [sigma < noise_thresholds[0]]
    for idx in range(len(noise_thresholds) - 1):
        band_indicators.append(
            jnp.logical_and(sigma >= noise_thresholds[idx], sigma < noise_thresholds[idx+1]))
    band_indicators.append(sigma >= noise_thresholds[-1])

    with_bonds = [
        net.apply({}, atom_feat, bond_feat, x, atom_mask, sigma, rg)[-1]
        for net in moledit_scorenets
    ]
    zero_bonds = [
        net.apply({}, atom_feat, jnp.zeros_like(bond_feat), x, atom_mask, sigma, rg)[-1]
        for net in moledit_scorenets
    ]
    mixed = gamma * jnp.array(with_bonds, jnp.float32) + \
        (1.0 - gamma) * jnp.array(zero_bonds, jnp.float32)

    weights = jnp.array(band_indicators, dtype=jnp.float32)[..., None, None]
    return jnp.sum(weights * mixed, axis=0)

# Batched, compiled score function handed to the samplers.
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    # Without this branch a typo would silently keep the `inference_fn` built by an
    # earlier section (and therefore the earlier score function).
    raise ValueError("Unknown INFERENCE_METHOD: {!r}".format(INFERENCE_METHOD))

In [None]:
constituents_dict = SMILES_to_constituents('O=CC1C(O)C(O)C(NC(C)=O)C(OC2=CC=C(NC3=C(C=C(Cl)C(Cl)=C4)C4=NC=C3C(O)=O)C=C2)O1')

atomic_numbers = constituents_dict['atomic_numbers']
hydrogen_numbers = constituents_dict['hydrogen_numbers']
hybridizations = constituents_dict['hybridizations']
bonds = constituents_dict['bonds']

n_atoms = len(atomic_numbers)
rg = 4.8

In [None]:
mol_with_index(Chem.MolFromSmiles('O=CC1C(O)C(O)C(NC(C)=O)C(OC2=CC=C(NC3=C(C=C(Cl)C(Cl)=C4)C4=NC=C3C(O)=O)C=C2)O1'))

In [None]:
twist_boat = np.array([
    [4.2964,    6.7528,   -3.7543], 
    [3.3185,    5.7022,   -3.2583],
    [3.1650,    5.6818,   -1.7379],
    [4.3181,    6.1715,   -1.1170],
    [4.5819,    7.5272,   -1.3917],
    [4.3223,    7.9427,   -2.8203],
])

chair = np.array([
    [-2.6536,    0.1260,   0.2929], 
    [-3.1483,    -1.2883,  0.4559],
    [-4.5988,    -1.2378,  0.8798],
    [-5.3909,    -0.4239,  -0.1226],
    [-4.7649,    0.9549,   -0.3028],
    [-3.4130,    0.8333,   -0.6571],
])

idx_map = {2:0, 3:1, 5:2, 7: 3, 12:4, 36: 5}

repaint_info = {
    "structure": np.stack([chair[idx_map[i]] if i in idx_map.keys() else np.zeros(3) 
                                 for i in range(NATOMS)]).astype(np.float32),
    "mask": np.array([True if i in idx_map.keys() else False for i in range(NATOMS)]).astype(np.bool_), 
}

from inference.utils import preprocess_data

print("Preprocessing inputs")
raw_info_dict = {"atomic_numbers": atomic_numbers, 
                 "hydrogen_numbers": hydrogen_numbers,
                 "hybridizations": hybridizations, 
                 "radius_of_gyrations": [rg, rg], 
                 "bonds": bonds}

input_dict = preprocess_data(raw_info_dict, NATOMS)
# Add a leading sample axis by repeating every array once per requested sample.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
input_dict = jax.tree_util.tree_map(lambda x:np.repeat(x[None, ...], 
                                                       NSAMPLE_PER_DEVICE * NDEVICES, axis=0), input_dict)
repaint_dict = jax.tree_util.tree_map(lambda x:np.repeat(x[None, ...], 
                                                         NSAMPLE_PER_DEVICE * NDEVICES, axis=0), repaint_info)

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))
for k, v in repaint_dict.items():
    print("\tRePaint {} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
input_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), input_dict)
repaint_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), repaint_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key, repaint_dict=repaint_dict)

#### save results 
# Create the output directory first so a finished sampling run is not lost to a missing folder.
os.makedirs('results/structure_editing', exist_ok=True)
with open('results/structure_editing/chair.pkl', 'wb') as f: 
    pkl.dump(jax.tree_util.tree_map(np.array, 
                          {'raw_infos': {**raw_info_dict, "repaint": repaint_info},
                           'trajectories': trajectories, 'structures': structures}), f)

#### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open(f'results/structure_editing/chair.pkl', 'rb') as f: 
    results = pkl.load(f)
    atomic_numbers = results['raw_infos']['atomic_numbers']
    hydrogen_numbers = results['raw_infos']['hydrogen_numbers']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 0
n_atoms = len(atomic_numbers)

structure = np.array(structures)[mol_id, :n_atoms, :]
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

### E/Z conformation

In [None]:
NSAMPLE_PER_DEVICE = 1
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES
NATOMS = 64 
INFERENCE_METHOD = "DPM_3"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Evaluate every net in `moledit_scorenets` and keep the one whose noise band contains `sigma`.

    Each net is applied twice, once with `bond_feat` and once with `bond_feat` zeroed out;
    only the last element of each net's output is used. The two results are mixed as
    ``gamma * with_bonds + (1 - gamma) * without_bonds``.

    Bands are delimited by `noise_thresholds` (t): the first net covers ``sigma < t[0]``,
    net i covers ``t[i-1] <= sigma < t[i]`` and the last net covers ``sigma >= t[-1]``.

    Returns the sum over nets of the mixed outputs weighted by the 0/1 band indicators.
    """
    #### 0/1 indicator per net: does sigma fall inside that net's noise band?
    lowest_band = sigma < noise_thresholds[0]
    middle_bands = [
        jnp.logical_and(sigma >= noise_thresholds[k], sigma < noise_thresholds[k + 1])
        for k in range(len(noise_thresholds) - 1)
    ]
    highest_band = sigma >= noise_thresholds[-1]
    band_weights = jnp.array([lowest_band, *middle_bands, highest_band], dtype=jnp.float32)

    def _apply_all(bond_input):
        #### last element of each net's output, for the given bond features
        return [net.apply({}, atom_feat, bond_input, x, atom_mask, sigma, rg)[-1]
                for net in moledit_scorenets]

    with_bonds = jnp.array(_apply_all(bond_feat), jnp.float32)
    without_bonds = jnp.array(_apply_all(jnp.zeros_like(bond_feat)), jnp.float32)
    mixed = gamma * with_bonds + (1.0 - gamma) * without_bonds

    return jnp.sum(band_weights[..., None, None] * mixed, axis=0)

#### map the score function over the leading axis of every argument, then jit the batched version
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
#### bind the sampler selected by INFERENCE_METHOD to the batched score function
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    #### fail loudly: otherwise `inference_fn` silently keeps whatever an earlier cell bound
    raise ValueError("Unknown INFERENCE_METHOD {!r}; expected 'DPM_3', 'DPM_pp_2S' or 'Langevin'"
                     .format(INFERENCE_METHOD))

In [None]:
constituents_dict = SMILES_to_constituents('Fc1ccc(/C=C/C[NH2+][C@@H]2[C@H]3C[C@@H]4CO[C@@H]2[C@H]4C3)cc1F')

atomic_numbers = constituents_dict['atomic_numbers']
hydrogen_numbers = constituents_dict['hydrogen_numbers']
hybridizations = constituents_dict['hybridizations']
bonds = constituents_dict['bonds']

n_atoms = len(atomic_numbers)
rg = 3.8

In [None]:
mol_with_index(Chem.MolFromSmiles('Fc1ccc(/C=C/C[NH2+][C@@H]2[C@H]3C[C@@H]4CO[C@@H]2[C@H]4C3)cc1F'))

In [None]:
E_double_bonds = np.array([
   [-2.226042  ,  1.2258013 ,  0.22110261],
   [-2.116217  , -0.18059452,  0.672148  ],
   [-1.52214   , -1.1383319 , -0.04937573],
   [-1.0976883 , -2.4374735 ,  0.586352  ]
])
Z_double_bonds = np.array([
   [ 0.4667195 ,  0.8490278 , -2.3810658 ],
   [-0.9392601 ,  1.1926131 , -2.6590502 ],
   [-1.9503369 ,  1.0611756 , -1.7846249 ],
   [-1.7876288 ,  0.67973304, -0.33311933]
])
#### molecule atom index -> row of the reference double-bond coordinates imposed on that atom
idx_map = {4: 0, 5: 1, 6: 2, 7: 3}

#### per-atom fixed coordinates (zeros for unconstrained atoms) and the matching boolean mask;
#### the Z reference geometry is the one applied here
fixed_coords = [Z_double_bonds[idx_map[atom]] if atom in idx_map else np.zeros(3)
                for atom in range(NATOMS)]
is_fixed = [atom in idx_map for atom in range(NATOMS)]
repaint_info = {
    "structure": np.stack(fixed_coords).astype(np.float32),
    "mask": np.array(is_fixed).astype(np.bool_),
}

from inference.utils import preprocess_data

print("Preprocessing inputs")
raw_info_dict = {"atomic_numbers": atomic_numbers, 
                 "hydrogen_numbers": hydrogen_numbers,
                 "hybridizations": hybridizations, 
                 "radius_of_gyrations": [rg, rg], 
                 "bonds": bonds}

input_dict = preprocess_data(raw_info_dict, NATOMS)
input_dict = jax.tree_map(lambda x:np.repeat(x[None, ...], 
                                             NSAMPLE_PER_DEVICE * NDEVICES, axis=0), input_dict)
repaint_dict = jax.tree_map(lambda x:np.repeat(x[None, ...], 
                                               NSAMPLE_PER_DEVICE * NDEVICES, axis=0), repaint_info)

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))
for k, v in repaint_dict.items():
    print("\tRePaint {} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
input_dict = jax.tree_map(lambda x:jnp.array(x), input_dict)
repaint_dict = jax.tree_map(lambda x:jnp.array(x), repaint_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key, repaint_dict=repaint_dict)

#### save results (create the output directory first so the cell also works on a fresh checkout)
os.makedirs('results/structure_editing', exist_ok=True)
with open('results/structure_editing/Z_conf.pkl', 'wb') as f: 
    #### convert every leaf (JAX arrays included) to NumPy before pickling
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'raw_infos': {**raw_info_dict, "repaint": repaint_info},
                                     'trajectories': trajectories, 'structures': structures}), f)

#### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open(f'results/structure_editing/Z_conf.pkl', 'rb') as f: 
    results = pkl.load(f)
    atomic_numbers = results['raw_infos']['atomic_numbers']
    hydrogen_numbers = results['raw_infos']['hydrogen_numbers']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 0
n_atoms = len(atomic_numbers)

structure = np.array(structures)[mol_id, :n_atoms, :]
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Redesign of Aromatic Rings

In [None]:
NSAMPLE_PER_DEVICE = 8
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES
NATOMS = 64 
INFERENCE_METHOD = "DPM_3"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Evaluate every net in `moledit_scorenets` and keep the one whose noise band contains `sigma`.

    Each net is applied twice, once with `bond_feat` and once with `bond_feat` zeroed out;
    only the last element of each net's output is used. The two results are mixed as
    ``gamma * with_bonds + (1 - gamma) * without_bonds``.

    Bands are delimited by `noise_thresholds` (t): the first net covers ``sigma < t[0]``,
    net i covers ``t[i-1] <= sigma < t[i]`` and the last net covers ``sigma >= t[-1]``.

    Returns the sum over nets of the mixed outputs weighted by the 0/1 band indicators.
    """
    #### 0/1 indicator per net: does sigma fall inside that net's noise band?
    lowest_band = sigma < noise_thresholds[0]
    middle_bands = [
        jnp.logical_and(sigma >= noise_thresholds[k], sigma < noise_thresholds[k + 1])
        for k in range(len(noise_thresholds) - 1)
    ]
    highest_band = sigma >= noise_thresholds[-1]
    band_weights = jnp.array([lowest_band, *middle_bands, highest_band], dtype=jnp.float32)

    def _apply_all(bond_input):
        #### last element of each net's output, for the given bond features
        return [net.apply({}, atom_feat, bond_input, x, atom_mask, sigma, rg)[-1]
                for net in moledit_scorenets]

    with_bonds = jnp.array(_apply_all(bond_feat), jnp.float32)
    without_bonds = jnp.array(_apply_all(jnp.zeros_like(bond_feat)), jnp.float32)
    mixed = gamma * with_bonds + (1.0 - gamma) * without_bonds

    return jnp.sum(band_weights[..., None, None] * mixed, axis=0)

#### map the score function over the leading axis of every argument, then jit the batched version
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
#### bind the sampler selected by INFERENCE_METHOD to the batched score function
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    #### fail loudly: otherwise `inference_fn` silently keeps whatever an earlier cell bound
    raise ValueError("Unknown INFERENCE_METHOD {!r}; expected 'DPM_3', 'DPM_pp_2S' or 'Langevin'"
                     .format(INFERENCE_METHOD))

In [None]:
mol_with_index(Chem.MolFromSmiles('Cc1cnc2c(cnn2CC(=O)N[C@@H]2[C@H]3C[C@@H]4CO[C@@H]2[C@H]4C3)c1'))

In [None]:
rg = 4.8
### result_1 (6 + 5, N, N, N)
atomic_numbers =   [6,6,6,7,6,6,6,7,7,6,6,8,7,6,6,6,6,6,8,6,6,6,6]
hydrogen_numbers = [3,0,1,0,0,0,1,0,0,2,0,0,1,1,1,2,1,2,0,1,1,2,1]
hybridizations =   [4,3,3,3,3,3,3,3,3,4,3,3,3,4,4,4,4,4,4,4,4,4,3]
### result_2 (6 + 5, N, N, O)
# atomic_numbers =   [6,6,6,8,6,6,7,7,6,6,6,8,7,6,6,6,6,6,8,6,6,6,6]
# hydrogen_numbers = [3,0,1,0,0,0,0,0,0,2,0,0,1,1,1,2,1,2,0,1,1,2,1]
# hybridizations =   [4,3,3,3,3,3,3,3,3,4,3,3,3,4,4,4,4,4,4,4,4,4,3]
### result_3 (5 + 5, N, N, N, O)
# atomic_numbers =   [6,6,7,8,6,6,7,7,6,6,6,8,7,6,6,6,6,6,8,6,6,6]
# hydrogen_numbers = [3,0,1,0,0,0,0,0,0,2,0,0,1,1,1,2,1,2,0,1,1,2]
# hybridizations =   [4,3,3,3,3,3,3,3,3,4,3,3,3,4,4,4,4,4,4,4,4,4]

### aromatic rings: 1,2,3,4,5,6,7,8,22
bonds = {
    0:  { 1: 1 },
	1:  { 0: 1, 2: 12, 22: 12 },
	2:  { 1: 12, 3: 12 },
	3:  { 2: 12, 4: 12 },
	4:  { 3: 12, 5: 12, 8: 12 },
	5:  { 4: 12, 6: 12, 22: 12 },
	6:  { 5: 12, 7: 12 },
	7:  { 6: 12, 8: 12 },
	8:  { 4: 12, 7: 12, 9: 1 },
	9:  { 8: 1, 10: 1 },
	10:  { 9: 1, 11: 2, 12: 1 },
	11:  { 10: 2 },
	12:  { 10: 1, 13: 1 },
	13:  { 12: 1, 14: 1, 19: 1 },
	14:  { 13: 1, 15: 1, 21: 1 },
	15:  { 14: 1, 16: 1 },
	16:  { 15: 1, 17: 1, 20: 1 },
	17:  { 16: 1, 18: 1 },
	18:  { 17: 1, 19: 1 },
	19:  { 13: 1, 18: 1, 20: 1 },
	20:  { 16: 1, 19: 1, 21: 1 },
	21:  { 14: 1, 20: 1 },
	22:  { 1: 12, 5: 12 },
}

#### delete aromatic bonds
bonds = {
    0:  { 1: 1 },
	1:  { 0: 1,},
	2:  {},
	3:  {},
	4:  {},
	5:  {},
	6:  {},
	7:  {},
	8:  { 9: 1 },
	9:  { 8: 1, 10: 1 },
	10:  { 9: 1, 11: 2, 12: 1 },
	11:  { 10: 2 },
	12:  { 10: 1, 13: 1 },
	13:  { 12: 1, 14: 1, 19: 1 },
	14:  { 13: 1, 15: 1, 21: 1 },
	15:  { 14: 1, 16: 1 },
	16:  { 15: 1, 17: 1, 20: 1 },
	17:  { 16: 1, 18: 1 },
	18:  { 17: 1, 19: 1 },
	19:  { 13: 1, 18: 1, 20: 1 },
	20:  { 16: 1, 19: 1, 21: 1 },
	21:  { 14: 1, 20: 1 },
}

In [None]:
from inference.utils import preprocess_data

print("Preprocessing inputs")
raw_info_dict = {"atomic_numbers": atomic_numbers, 
                 "hydrogen_numbers": hydrogen_numbers,
                 "hybridizations": hybridizations, 
                 "radius_of_gyrations": [rg, rg], 
                 "bonds": bonds}

input_dict = preprocess_data(raw_info_dict, NATOMS)
input_dict = jax.tree_map(lambda x:np.repeat(x[None, ...], 
                                             NSAMPLE_PER_DEVICE * NDEVICES, axis=0), input_dict)

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
input_dict = jax.tree_map(lambda x:jnp.array(x), input_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key)

#### save results (create the output directory first so the cell also works on a fresh checkout)
os.makedirs('results/redesign_aromatic_rings', exist_ok=True)
with open('results/redesign_aromatic_rings/result_1.pkl', 'wb') as f: 
    #### convert every leaf (JAX arrays included) to NumPy before pickling
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'raw_infos': {**raw_info_dict},
                                     'trajectories': trajectories, 'structures': structures}), f)

#### Graph Assembly

In [None]:
from graph_assembler.graph_assembler import assemble_mol_graph, check_bonds

structures = np.array(structures)
success_or_not = []   #### one flag per sampled structure
smileses = []         #### SMILES of each assembled graph, "" when assembly failed
#### wrap the array itself (not the enumerate generator) so tqdm can show a total
for i, structure in enumerate(tqdm(structures)):
    success, Xponge_mol, smiles = assemble_mol_graph(atomic_numbers, hydrogen_numbers, structure)
    success_or_not.append(success)
    smileses.append(smiles if success else "")

#### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results (same path the inference cell of this section writes to)
with open('results/redesign_aromatic_rings/result_1.pkl', 'rb') as f: 
    results = pkl.load(f)
atomic_numbers = results['raw_infos']['atomic_numbers']
hydrogen_numbers = results['raw_infos']['hydrogen_numbers']
trajectories = results['trajectories']
structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 0
n_atoms = len(atomic_numbers)

structure = np.array(structures)[mol_id, :n_atoms, :]
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Design with respect to a Functional-Core

In [None]:
NSAMPLE_PER_DEVICE = 8
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES
NATOMS = 64 
INFERENCE_METHOD = "DPM_3"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Evaluate every net in `moledit_scorenets` and keep the one whose noise band contains `sigma`.

    Each net is applied twice, once with `bond_feat` and once with `bond_feat` zeroed out;
    only the last element of each net's output is used. The two results are mixed as
    ``gamma * with_bonds + (1 - gamma) * without_bonds``.

    Bands are delimited by `noise_thresholds` (t): the first net covers ``sigma < t[0]``,
    net i covers ``t[i-1] <= sigma < t[i]`` and the last net covers ``sigma >= t[-1]``.

    Returns the sum over nets of the mixed outputs weighted by the 0/1 band indicators.
    """
    #### 0/1 indicator per net: does sigma fall inside that net's noise band?
    lowest_band = sigma < noise_thresholds[0]
    middle_bands = [
        jnp.logical_and(sigma >= noise_thresholds[k], sigma < noise_thresholds[k + 1])
        for k in range(len(noise_thresholds) - 1)
    ]
    highest_band = sigma >= noise_thresholds[-1]
    band_weights = jnp.array([lowest_band, *middle_bands, highest_band], dtype=jnp.float32)

    def _apply_all(bond_input):
        #### last element of each net's output, for the given bond features
        return [net.apply({}, atom_feat, bond_input, x, atom_mask, sigma, rg)[-1]
                for net in moledit_scorenets]

    with_bonds = jnp.array(_apply_all(bond_feat), jnp.float32)
    without_bonds = jnp.array(_apply_all(jnp.zeros_like(bond_feat)), jnp.float32)
    mixed = gamma * with_bonds + (1.0 - gamma) * without_bonds

    return jnp.sum(band_weights[..., None, None] * mixed, axis=0)

#### map the score function over the leading axis of every argument, then jit the batched version
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
#### bind the sampler selected by INFERENCE_METHOD to the batched score function
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    #### fail loudly: otherwise `inference_fn` silently keeps whatever an earlier cell bound
    raise ValueError("Unknown INFERENCE_METHOD {!r}; expected 'DPM_3', 'DPM_pp_2S' or 'Langevin'"
                     .format(INFERENCE_METHOD))

In [None]:
n_atoms = 38
rg = 4.8
atomic_numbers = [6, 6, 6, 6, 6, 6, 7, 6, 6, 7, 6, 6, 6, 6, 6, 17, 17, 6, 6, 6, 8, 6, 8, 8, 6, 8, 6, 6, 6, 6, 7, 6, 8, 6, 8, 8, 6, 8]
hydrogen_numbers = [0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 3, 1, 1, 2, 1]
hybridizations = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 3, 3, 3, 3, 4, 3, 3, 4, 4, 4, 4, 4, 4, 3, 3, 3, 4, 4, 4, 4, 4]

In [None]:
bonds_0 = {
    23: { 24: 1 },
    24: { 25: 1, 29: 1 },
    25: { 26: 1 },
    26: { 27: 1, 36: 1 },
    27: { 28: 1, 35: 1 },
    28: { 29: 1, 34: 1 },
    29: { 30: 1 },
    30: { 31: 1 },
    31: { 32: 2, 33: 1 },
    36: { 37: 1 },
}
# bonds_1 = {
#     23: { 28: 1 },
#     24: { 25: 1, 29: 1, 34: 1 },
#     25: { 26: 1 },
#     26: { 27: 1, 36: 1 },
#     27: { 28: 1, 35: 1 },
#     28: { 29: 1 },
#     29: { 30: 1 },
#     30: { 31: 1 },
#     31: { 32: 2, 33: 1 },
#     36: { 37: 1 },
# }
# bonds_2 = {
#     23: { 27: 1 },
#     24: { 25: 1, 29: 1, 35: 1 },
#     25: { 26: 1 },
#     26: { 27: 1, 36: 1 },
#     27: { 28: 1 },
#     28: { 29: 1, 34: 1 },
#     29: { 30: 1 },
#     30: { 31: 1 },
#     31: { 32: 2, 33: 1 },
#     36: { 37: 1 },
# }

from inference.utils import preprocess_data

print("Preprocessing inputs")
raw_info_dict = {"atomic_numbers": atomic_numbers, 
                 "hydrogen_numbers": hydrogen_numbers,
                 "hybridizations": hybridizations, 
                 "radius_of_gyrations": [rg, rg],
                 "bonds": bonds_0,}

input_dict = preprocess_data(raw_info_dict, NATOMS)
input_dict = jax.tree_map(lambda x:np.repeat(x[None, ...], 
                                             NSAMPLE_PER_DEVICE * NDEVICES, axis=0), input_dict)

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
input_dict = jax.tree_map(lambda x:jnp.array(x), input_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key)

#### save results (create the output directory first so the cell also works on a fresh checkout)
os.makedirs('results/design_wrt_functional_core', exist_ok=True)
with open('results/design_wrt_functional_core/result_0.pkl', 'wb') as f: 
    #### convert every leaf (JAX arrays included) to NumPy before pickling
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'raw_infos': raw_info_dict,
                                     'trajectories': trajectories, 'structures': structures}), f)

### Graph Assembly

In [None]:
from graph_assembler.graph_assembler import assemble_mol_graph, check_bonds

structures = np.array(structures)
#### bond constraints that were fed to the model in this section (the bare name `bonds`
#### still refers to the dictionary of an earlier section and must not be used here)
target_bonds = raw_info_dict["bonds"]
success_or_not = []          #### one flag per sampled structure
smileses = []                #### SMILES of each assembled graph, "" when assembly failed
bonds_correct_or_not = []    #### whether the assembled graph passes the check_bonds comparison
#### wrap the array itself (not the enumerate generator) so tqdm can show a total
for i, structure in enumerate(tqdm(structures)):
    success, Xponge_mol, smiles = assemble_mol_graph(atomic_numbers, hydrogen_numbers, structure)
    success_or_not.append(success)
    smileses.append(smiles if success else "")

    #### export to mol2
    if success:
        ##### delete Hs (Hs are added to help recognizing topology, their coordinates are fake)
        atoms = Xponge_mol.atoms[::1]
        #### delete from the highest index down so the remaining indices stay valid
        hydrogen_atom_idx = np.sort([idx for idx, atom in enumerate(atoms) if 'H' in atom])[::-1]
        for atom_idx in hydrogen_atom_idx:
            Xponge_mol.delete_atom(atom_idx)

        #### NOTE(review): ret[0] / ret[1] are presumably matched vs. expected bond counts -- confirm in check_bonds
        ret = check_bonds(Xponge_mol.bonds, target_bonds, allow_perm=False)
        bonds_correct_or_not.append(ret[0] == ret[1])
    else:
        bonds_correct_or_not.append(False)

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open(f'results/design_wrt_functional_core/result_0.pkl', 'rb') as f: 
    results = pkl.load(f)
    atomic_numbers = results['raw_infos']['atomic_numbers']
    hydrogen_numbers = results['raw_infos']['hydrogen_numbers']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 5
n_atoms = len(atomic_numbers)

structure = np.array(structures)[mol_id, :n_atoms, :]
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Linker Design

In [None]:
NSAMPLE_PER_DEVICE = 8
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES

In [None]:
from rdkit import Chem 
from rdkit.Chem import Draw

smi = 'CC(NC1=CC(N=C(C2=CC=CC=C2)O3)=C3C=C1)=O' # sys 0
# smi = 'CC(C)C1=CC=C(C(C)CC(O)C)C=C1OC' # sys 1
mol_with_index(Chem.MolFromSmiles(smi))

In [None]:
constituents_dict = SMILES_to_constituents(smi)
atomic_numbers = constituents_dict['atomic_numbers']
hydrogen_numbers = constituents_dict['hydrogen_numbers']
hybridizations = constituents_dict['hybridizations']
bonds = constituents_dict['bonds']
print("atomic numbers:  ", ",".join([str(x) for x in atomic_numbers]))
print("hydrogen_numbers:", ",".join([str(x) for x in hydrogen_numbers]))
print("hybridizations:  ", ",".join([str(x) for x in hybridizations]))
for atom_a, bonds_a in bonds.items():
    if len(bonds_a.keys()) > 0:
        print("\t{}: ".format(atom_a), "{", ", ".join(["{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()]), "},")

In [None]:
### design linker between atom 7 - atom 8 (sys 0)
atomic_numbers = [6,6,7,6,6,6,7,6,6,6,6,6,6,6,8,6,6,6,8]
hydrogen_numbers = [3,0,1,0,1,0,0,0,0,1,1,1,1,1,0,0,1,1,0]
hybridizations = [4,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3]

bonds = {
    0:  { 1: 1 },
	1:  { 0: 1, 2: 1, 18: 2 },
	2:  { 1: 1, 3: 1 },
	3:  { 2: 1, 4: 12, 17: 12 },
	4:  { 3: 12, 5: 12 },
	5:  { 4: 12, 6: 12, 15: 12 },
	6:  { 5: 12, 7: 12 },
	7:  { 6: 12, 14: 12 },
	8:  { 9: 12, 13: 12 },
	9:  { 8: 12, 10: 12 },
	10:  { 9: 12, 11: 12 },
	11:  { 10: 12, 12: 12 },
	12:  { 11: 12, 13: 12 },
	13:  { 12: 12, 8: 12 },
	14:  { 7: 12, 15: 12 },
	15:  { 5: 12, 16: 12, 14: 12 },
	16:  { 15: 12, 17: 12 },
	17:  { 16: 12, 3: 12 },
	18:  { 1: 2 },
}
#### encode rg as three tokens: floor(ln rg), then the units digit and the first
#### decimal of rg / e**floor(ln rg) rounded to one decimal place
rg = 5.0
exponents = math.floor(math.log(rg))
numbers = np.round(rg / math.exp(exponents), 1)
numbers_str = str(numbers) ### x.yz...
units_digit, first_decimal = numbers_str[0], numbers_str[2]
rg_tokens = [int(part) for part in (exponents, units_digit, first_decimal)]

# ### design linker between atom 6 - atom 7 (sys 1)
# atomic_numbers = [6,6,6,6,6,6,6,6,6,6,6,8,6,6,6,8,6]
# hydrogen_numbers = [3,1,3,0,1,1,0,1,3,2,1,1,3,1,0,0,3]
# hybridizations = [4,4,4,3,3,3,3,4,4,4,4,4,4,3,3,3,4] 
# bonds = {
#     00:  { 1: 1 },
# 	1:  { 0: 1, 2: 1, 3: 1 },
# 	2:  { 1: 1 },
# 	3:  { 1: 1, 4: 12, 14: 12 },
# 	4:  { 3: 12, 5: 12 },
# 	5:  { 4: 12, 6: 12 },
# 	6:  { 5: 12, 13: 12 },
# 	7:  { 8: 1, 9: 1 },
# 	8:  { 7: 1 },
# 	9:  { 7: 1, 10: 1 },
# 	10:  { 9: 1, 11: 1, 12: 1 },
# 	11:  { 10: 1 },
# 	12:  { 10: 1 },
# 	13:  { 6: 12, 14: 12 },
# 	14:  { 13: 12, 15: 1, 3: 12 },
# 	15:  { 14: 1, 16: 1 },
# 	16:  { 15: 1 },
# }
# rg = 5.0
# exponents = math.floor(math.log(rg))
# numbers = np.round(rg / math.exp(exponents), 1)
# numbers_str = str(numbers) ### x.yz...
# rg_tokens = [int(exponents), int(numbers_str[0]), int(numbers_str[2])]

### Constituents Sampling

In [None]:
group_constituents_str = ["{}_{}_{}".format(i, j, k) for i, j, k in zip(atomic_numbers, 
                                                                        hydrogen_numbers, 
                                                                        hybridizations)]
inputs = np.array([group_constituents_str.count(t) for t in constituent_vocab_list] + (np.array(rg_tokens) + NATOMS).tolist(), 
                  dtype=np.int32)
input_dict = {
    "inputs": jnp.array([inputs,] * NSAMPLES, dtype=jnp.int32) + 1, 
    "generation_result": jnp.ones((NSAMPLES, SEQ_LEN), dtype=jnp.int32)
}

rng_keys = jax.random.split(rng_key, NSAMPLES*SEQ_LEN + 1)
rng_keys, rng_key = rng_keys[:NSAMPLES*SEQ_LEN].reshape(NSAMPLES, SEQ_LEN, -1), rng_keys[-1]

if SHARDING:
    #### shard inputs 
    ds_sharding = partial(_sharding, shards=global_sharding)
    input_dict = jax.tree_map(ds_sharding, input_dict)
    rng_keys = ds_sharding(rng_keys)

inv_temperature = 1.25
for step in tqdm(range(SEQ_LEN)):
    logits = jitted_logits_fn(params, 
                              input_dict['inputs'],
                              input_dict['generation_result'])
    if step >= NCONSTITUENTS:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
    else:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., 1:-NRG_VOCABS].set(1)
    logits += (-1e5) * (1.0 - valid_logits_mask)
    sampled_token, rng_keys = top_p_sampling_fn(logits * inv_temperature, rng_keys)
    input_dict['generation_result'] = input_dict['generation_result'].at[..., step].set(sampled_token[..., step])

##### resample radius of gyrations 
input_dict['generation_result'] = input_dict['generation_result'] + input_dict['inputs'] - 1
input_dict['inputs'] = jnp.ones_like(input_dict['inputs'], dtype=input_dict['inputs'].dtype)
input_dict['inputs'] = input_dict['inputs'].at[..., NCONSTITUENTS:].set(NATOMS + NRG_VOCABS) ### unk tokens for rg
input_dict['generation_result'] = input_dict['generation_result'].at[..., NCONSTITUENTS:].set(NATOMS + NRG_VOCABS)
inv_temperature = 1.25
for step in tqdm(range(NCONSTITUENTS, SEQ_LEN)):
    logits = jitted_logits_fn(params, 
                              input_dict['inputs'],
                              input_dict['generation_result'])
    if step >= NCONSTITUENTS:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
    else:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., 1:-NRG_VOCABS].set(1)
    logits += (-1e5) * (1.0 - valid_logits_mask)
    sampled_token, rng_keys = top_p_sampling_fn(logits * inv_temperature, rng_keys)
    input_dict['generation_result'] = input_dict['generation_result'].at[..., step].set(sampled_token[..., step])

#### Decode the sampled token sequences into one constituent dictionary per sample.
#### Undo the +1 offset that was added when `input_dict['inputs']` was built.
generation_result = np.array(input_dict['generation_result']) - 1
#### remove the counts of the atoms that were given as input, leaving only the extra counts per vocab entry
generation_result[..., :NCONSTITUENTS] = generation_result[..., :NCONSTITUENTS] - inputs[None, :NCONSTITUENTS]
constituents = []
for seq in tqdm(generation_result):
    #### decode constituents
    #### start from copies of the fixed atoms, then append the sampled ones
    atomic_numbers_ = atomic_numbers[:]
    hydrogen_numbers_ = hydrogen_numbers[:]
    hybridizations_ = hybridizations[:]
    for token, num in zip(constituent_vocab_list, seq[:NCONSTITUENTS]):
        #### vocab entries are "atomicnumber_hydrogens_hybridization" strings
        atomic_number, hydrogen_number, hybridization = tuple([int(x) for x in token.split('_')])
        #### append `num` copies of this constituent (nothing is appended when num <= 0)
        atomic_numbers_ = atomic_numbers_ +  [atomic_number,] * num 
        hydrogen_numbers_ = hydrogen_numbers_ + [hydrogen_number,] * num 
        hybridizations_ = hybridizations_ + [hybridization,] * num
        
    #### decode rg
    #### the last NRG_TOKENS entries carry the NATOMS offset: exponent first, then the digits
    rg_seq = seq[-NRG_TOKENS:] - NATOMS
    #### rg = e**exponent * d.ddd, where the first digit is the integer part and the rest are decimals
    rg_ = np.exp(rg_seq[0]) * float("{}.{}".format(rg_seq[1], "".join([str(x) for x in rg_seq[2:]])))
    constituents.append(
        {"atomic_numbers": np.array(atomic_numbers_, dtype=np.uint8), 
         "hydrogen_numbers": np.array(hydrogen_numbers_, dtype=np.uint8),
         "hybridizations": np.array(hybridizations_, dtype=np.uint8), 
         "radius_of_gyrations": np.array([rg_,], dtype=np.float32)} # the decoded (sampled) rg is used here
    )

### Structure Sampling

In [None]:
from inference.utils import preprocess_data

print("Preprocessing inputs")
input_dicts = [preprocess_data({**c, "bonds": bonds}, NATOMS) for c in tqdm(constituents)]
input_dict = {
    k: np.stack([d[k] for d in input_dicts]) for k in input_dicts[0].keys()
}

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
input_dict = jax.tree_map(lambda x:jnp.array(x), input_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key)

#### save results (create the output directory first so the cell also works on a fresh checkout)
os.makedirs('results/linker_design', exist_ok=True)
with open('results/linker_design/result_0.pkl', 'wb') as f: 
    #### convert every leaf (JAX arrays included) to NumPy before pickling
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'constituents': constituents,
                                     'bonds': bonds,
                                     'trajectories': trajectories, 'structures': structures}), f)

### Graph Assembly

In [None]:
from graph_assembler.graph_assembler import assemble_mol_graph, check_bonds

structures = np.array(structures)
success_or_not = []   #### one flag per sampled structure
smileses = []         #### SMILES of each assembled graph, "" when assembly failed
#### materialise the pairs so tqdm knows the total (a bare zip/enumerate has no length)
sample_pairs = list(zip(constituents, structures))
for i, (constituent, structure) in enumerate(tqdm(sample_pairs)):
    success, Xponge_mol, smiles = assemble_mol_graph(constituent['atomic_numbers'], constituent['hydrogen_numbers'], structure)
    success_or_not.append(success)
    smileses.append(smiles if success else "")

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results (same path the structure-sampling cell of this section writes to;
#### change both places together when switching to another system)
with open('results/linker_design/result_0.pkl', 'rb') as f: 
    results = pkl.load(f)
constituents = results['constituents']
trajectories = results['trajectories']
structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 3
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
n_atoms = len(atomic_numbers)

structure = np.array(structures)[mol_id, :n_atoms, :]
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Building Structures for Free Energy Perturbations (FEP)

In [None]:
NSAMPLE_PER_DEVICE = 8
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES

### Scaffold Hopping

In [None]:
### system in ref: (Scaffold Hopping and Optimization of Small Molecule Soluble Adenyl Cyclase Inhibitors Led by Free Energy Perturbation)
### complex system: 5IV4

smi1 = 'ClC1=CC(N(CC2=CC=CS2)C3CC3)=NC(N)=N1'
smi2 = 'ClC1=CC(C2=NN(C)C=C2CC3=CC=CS3)=NC(N)=N1'

structure_1_in_complex = np.array(
    [
        [6.964, 27.228, -2.996],
        [8.493, 26.639, -2.312],
        [9.628, 27.353, -2.035],
        [10.739, 26.750, -1.419],
        [11.887, 27.354, -1.221],
        [13.063, 26.577, -0.917],
        [13.661, 25.909, -2.201],
        [14.581, 26.467, -2.990],
        [14.864, 25.596, -4.013],
        [14.151, 24.478, -4.116],
        [13.101, 24.433, -2.771],
        [12.182, 28.620, -1.867],
        [11.680, 29.889, -1.120],
        [13.167, 29.497, -1.232],
        [10.629, 25.468, -1.102],
        [9.563, 24.799, -1.444], 
        [9.578, 23.512, -1.223], 
        [8.480, 25.350, -1.996]
    ]
)

### we need to sample structure 2 consistent with structure 1 (common structures, two aromatic rings)
index_map_from_2_to_1 = {
    0:0, 1:1, 2:2, 3:3, 16:14, 17:15, 18:16, 19:17, 11:6, 12:7, 13:8, 14:9, 15:10
}

#### radius of gyration of structure 1: root-mean-square distance of its atoms from their centroid
centroid_1 = np.mean(structure_1_in_complex, axis=0, keepdims=True)
distances_to_centroid_1 = np.linalg.norm(structure_1_in_complex - centroid_1, axis=-1)
rg_1_in_complex = np.sqrt(np.mean(distances_to_centroid_1**2))

In [None]:
mol_with_index(Chem.MolFromSmiles(smi2))

In [None]:
mol_with_index(Chem.MolFromSmiles(smi1))

In [None]:
constituents_dict = SMILES_to_constituents(smi2)
atomic_numbers = constituents_dict['atomic_numbers']
hydrogen_numbers = constituents_dict['hydrogen_numbers']
hybridizations = constituents_dict['hybridizations']
bonds = constituents_dict['bonds']
print("atomic numbers:  ", ",".join([str(x) for x in atomic_numbers]))
print("hydrogen_numbers:", ",".join([str(x) for x in hydrogen_numbers]))
print("hybridizations:  ", ",".join([str(x) for x in hybridizations]))
for atom_a, bonds_a in bonds.items():
    if len(bonds_a.keys()) > 0:
        print("\t{}: ".format(atom_a), "{", ", ".join(["{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()]), "},")

#### Constituents Sampling

In [None]:
### sample rg (experiments show that using decoder to sample rg is more robust)
constituents_str = ["{}_{}_{}".format(i, j, k) for i, j, k in zip(atomic_numbers, 
                                                                        hydrogen_numbers, 
                                                                        hybridizations)]
#### how many atoms of THIS molecule fall on each vocabulary entry
constituent_counts = np.array([constituents_str.count(t) for t in constituent_vocab_list], dtype=np.int32)
#### token ids are the counts shifted by +1
inputs_constituents = jnp.array(constituent_counts, dtype=jnp.int32) + 1
input_dict = {
    "inputs": jnp.ones((NSAMPLES, SEQ_LEN), dtype=jnp.int32),
    "generation_result": jnp.zeros((NSAMPLES, SEQ_LEN), dtype=jnp.int32)
}

#### one key per (sample, position), plus one key carried over to the following cells
rng_keys = jax.random.split(rng_key, NSAMPLES*SEQ_LEN + 1)
rng_keys, rng_key = rng_keys[:NSAMPLES*SEQ_LEN].reshape(NSAMPLES, SEQ_LEN, -1), rng_keys[-1]

##### resample radius of gyrations 
input_dict['inputs'] = input_dict['inputs'].at[..., NCONSTITUENTS:].set(NATOMS + NRG_VOCABS) ### unk tokens for rg
input_dict['generation_result'] = input_dict['generation_result'].at[..., NCONSTITUENTS:].set(NATOMS + NRG_VOCABS) ### unk tokens for rg
#### the constituent positions are fixed to the known molecule; only the rg positions are sampled
input_dict['generation_result'] = input_dict['generation_result'].at[..., :NCONSTITUENTS].set(inputs_constituents[None, ...])

if SHARDING:
    #### shard inputs 
    ds_sharding = partial(_sharding, shards=global_sharding)
    input_dict = jax.tree_util.tree_map(ds_sharding, input_dict)
    rng_keys = ds_sharding(rng_keys)

inv_temperature = 1.25
for step in tqdm(range(NCONSTITUENTS, SEQ_LEN)):
    logits = jitted_logits_fn(params, 
                              input_dict['inputs'],
                              input_dict['generation_result'])
    #### every step of this loop is an rg position, so only the logit slice [-NRG_VOCABS:-1] is allowed
    valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
    logits += (-1e5) * (1.0 - valid_logits_mask)
    sampled_token, rng_keys = top_p_sampling_fn(logits * inv_temperature, rng_keys)
    input_dict['generation_result'] = input_dict['generation_result'].at[..., step].set(sampled_token[..., step])

#### undo the +1 token offset
generation_result = np.array(input_dict['generation_result']) - 1
#### subtract this molecule's own counts (not the `inputs` array of the linker-design section),
#### so the constituent positions read as "extra atoms per vocab entry" and are all zero here
generation_result[..., :NCONSTITUENTS] = generation_result[..., :NCONSTITUENTS] - constituent_counts[None, :NCONSTITUENTS]
constituents = []
for seq in tqdm(generation_result):
    #### decode rg
    #### the last NRG_TOKENS entries carry the NATOMS offset: exponent first, then the digits
    rg_seq = seq[-NRG_TOKENS:] - NATOMS
    rg_ = np.exp(rg_seq[0]) * float("{}.{}".format(rg_seq[1], "".join([str(x) for x in rg_seq[2:]])))
    #### the sampled rg_ is decoded but deliberately not used below: the rg measured on
    #### structure 1 in the complex is used instead; swap in np.array([rg_,], dtype=np.float32) to change that
    constituents.append(
        {"atomic_numbers": np.array(atomic_numbers, dtype=np.uint8), 
         "hydrogen_numbers": np.array(hydrogen_numbers, dtype=np.uint8),
         "hybridizations": np.array(hybridizations, dtype=np.uint8), 
         "radius_of_gyrations": np.array([rg_1_in_complex, ], dtype=np.float32)}
    ) ### using rg of structure 1?

#### Structure Sampling

In [None]:
#### RePaint conditioning: every atom index i listed in index_map_from_2_to_1 is pinned to the
#### coordinates of its mapped atom in structure_1_in_complex; all other slots are zero and unmasked.
repaint_info = {
    "structure": np.stack([structure_1_in_complex[index_map_from_2_to_1[i]] if i in index_map_from_2_to_1 else np.zeros(3)
                           for i in range(NATOMS)]).astype(np.float32),
    "mask": np.array([i in index_map_from_2_to_1 for i in range(NATOMS)], dtype=np.bool_),
}

from inference.utils import preprocess_data

print("Preprocessing inputs")
#### every sampled constituent set shares the same bond table
input_dicts = [preprocess_data({**c, "bonds": bonds}, NATOMS) for c in tqdm(constituents)]
input_dict = {
    k: np.stack([d[k] for d in input_dicts]) for k in input_dicts[0].keys()
}
#### the same RePaint template is repeated along a new leading batch axis
repaint_dict = jax.tree_map(lambda x:np.repeat(x[None, ...],
                                               NSAMPLE_PER_DEVICE * NDEVICES, axis=0), repaint_info)

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))
for k, v in repaint_dict.items():
    print("\tRePaint {} shape: {} dtype: {}".format(k, v.shape, v.dtype))

input_dict = jax.tree_map(lambda x:jnp.array(x), input_dict)
repaint_dict = jax.tree_map(lambda x:jnp.array(x), repaint_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key, repaint_dict=repaint_dict)

#### save results
#### create the output directory first so the (expensive) inference results are not lost
#### to a FileNotFoundError on a fresh checkout
os.makedirs('results/FEP_structure_build', exist_ok=True)
with open('results/FEP_structure_build/scaffold_hopping.pkl', 'wb') as f:
    pkl.dump(jax.tree_map(np.array,
                          {'constituents': constituents,
                           'repaint_info': repaint_info,
                           'bonds': bonds,
                           'trajectories': trajectories, 'structures': structures}), f)

#### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
#### (pickle is only safe here because the file was written by this notebook)
with open('results/FEP_structure_build/scaffold_hopping.pkl', 'rb') as f: 
    results = pkl.load(f)
    constituents = results['constituents']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 3
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
n_atoms = len(atomic_numbers)

#### drop the padding slots beyond the real atoms
structure = np.array(structures)[mol_id, :n_atoms, :]
#### the radius of gyration is measured from the centroid; the sampled coordinates are not
#### guaranteed to be centred (RePaint pins atoms to the template frame), so centre them
#### for this calculation only and keep `structure` itself untouched for viewing.
rg = np.sqrt(np.sum((structure - np.mean(structure, axis=0, keepdims=True)) ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
#### NOTE(review): `smileses` is not produced in this section -- confirm that it refers to
#### these structures and is not left over from an earlier graph-assembly cell.
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
#### atom names are "<element symbol><atom index>"
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

### R-group modification

#### Display R-group modifications

In [None]:
index_maps = {} #### dictionary to store fixed atoms
### smi2: 'ClC1=CC(C2=NN(C)C=C2CC3=CC=CS3)=NC(N)=N1'
constituents_dict = SMILES_to_constituents(smi2)
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))
#### print the bond table, skipping atoms without any listed bond
for atom_a, bonds_a in bonds.items():
    if len(bonds_a) == 0:
        continue
    print("\t{}: ".format(atom_a), "{",
          ", ".join("{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()),
          "},")

#### the reference molecule maps onto itself: atoms 0-11 and 16-19 are kept fixed
index_maps[smi2] = {i: i for i in [*range(12), *range(16, 20)]}

In [None]:
### modification 1: 
smi_2_1 = 'ClC1=CC(C2=NN(C)C=C2CC3=CC=CC=C3)=NC(N)=N1'
constituents_dict = SMILES_to_constituents(smi_2_1)
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))
#### print the bond table, skipping atoms without any listed bond
for atom_a, bonds_a in bonds.items():
    if len(bonds_a) == 0:
        continue
    print("\t{}: ".format(atom_a), "{",
          ", ".join("{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()),
          "},")
#### fixed atoms, as {atom index in this molecule: atom index in the template structure}
index_maps[smi_2_1] = {
    **{i: i for i in range(12)},
    17: 16, 18: 17, 19: 18, 20: 19,
}
mol_with_index(Chem.MolFromSmiles(smi_2_1))

In [None]:
### modification 2: 
smi_2_2 = 'ClC1=CC(C2=NN(C(F)(F)F)C=C2CC3=CC=CC=C3)=NC(N)=N1'
constituents_dict = SMILES_to_constituents(smi_2_2)
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))
#### print the bond table, skipping atoms without any listed bond
for atom_a, bonds_a in bonds.items():
    if len(bonds_a) == 0:
        continue
    print("\t{}: ".format(atom_a), "{",
          ", ".join("{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()),
          "},")
#### fixed atoms, as {atom index in this molecule: atom index in the template structure}
index_maps[smi_2_2] = {
    **{i: i for i in range(8)},
    11: 8, 12: 9, 13: 10, 14: 11,
    20: 16, 21: 17, 22: 18, 23: 19,
}
mol_with_index(Chem.MolFromSmiles(smi_2_2))

In [None]:
### modification 3: 
smi_2_3 = 'ClC1=CC(C2=NN(C)C=C2CC3=NC=CS3)=NC(N)=N1'
constituents_dict = SMILES_to_constituents(smi_2_3)
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))
#### print the bond table, skipping atoms without any listed bond
for atom_a, bonds_a in bonds.items():
    if len(bonds_a) == 0:
        continue
    print("\t{}: ".format(atom_a), "{",
          ", ".join("{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()),
          "},")
#### fixed atoms, as {atom index in this molecule: atom index in the template structure};
#### here the kept atoms 0-11 and 16-19 map onto the same indices
index_maps[smi_2_3] = {i: i for i in [*range(12), *range(16, 20)]}
mol_with_index(Chem.MolFromSmiles(smi_2_3))

In [None]:
### modification 4: 
smi_2_4 = 'ClC1=CC(C2=NN(C)C=C2CN(C(C)=O)C3CCC3)=NC(N)=N1'
constituents_dict = SMILES_to_constituents(smi_2_4)
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))
#### print the bond table, skipping atoms without any listed bond
for atom_a, bonds_a in bonds.items():
    if len(bonds_a) == 0:
        continue
    print("\t{}: ".format(atom_a), "{",
          ", ".join("{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()),
          "},")
#### fixed atoms, as {atom index in this molecule: atom index in the template structure}
index_maps[smi_2_4] = {
    **{i: i for i in range(12)},
    19: 16, 20: 17, 21: 18, 22: 19,
}
mol_with_index(Chem.MolFromSmiles(smi_2_4))

In [None]:
### modification 5: 
smi_2_5 = 'ClC1=CC(C2=NN(C)C(CO)=C2CC3=CC=CC=C3)=NC(N)=N1'
constituents_dict = SMILES_to_constituents(smi_2_5)
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))
#### print the bond table, skipping atoms without any listed bond
for atom_a, bonds_a in bonds.items():
    if len(bonds_a) == 0:
        continue
    print("\t{}: ".format(atom_a), "{",
          ", ".join("{}: {}".format(atom_b, btype) for atom_b, btype in bonds_a.items()),
          "},")
#### fixed atoms, as {atom index in this molecule: atom index in the template structure}
index_maps[smi_2_5] = {
    **{i: i for i in range(9)},
    11: 9, 12: 10, 13: 11,
    19: 16, 20: 17, 21: 18, 22: 19,
}
mol_with_index(Chem.MolFromSmiles(smi_2_5))

In [None]:
### structures of smi_2:
#### the first saved scaffold-hopping sample, cut to its real atoms, serves as the template
with open(f'results/FEP_structure_build/scaffold_hopping.pkl', 'rb') as f:
    results = pkl.load(f)
    structure_template = results['structures'][0][:len(results['constituents'][0]['atomic_numbers'])]
#### radius of gyration of the template around its centroid
rg = np.sqrt(
    np.mean(
        np.linalg.norm(structure_template - np.mean(structure_template, axis=0, keepdims=True), axis=-1)**2
    )
)

constituents, bonds, input_dicts, repaint_infos = [], [], [], []
NSAMPLE_PER_SYS = 8
#### every system contributes NSAMPLE_PER_SYS identical entries to each of the four lists
for smi, index_map in index_maps.items():
    constituents_dict = SMILES_to_constituents(smi)
    constituents += [
        {**{key: constituents_dict[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations')},
         'radius_of_gyrations': [rg, ]}
    ] * NSAMPLE_PER_SYS
    bonds += [constituents_dict['bonds'],] * NSAMPLE_PER_SYS
    input_dicts += [preprocess_data({**constituents_dict, 'radius_of_gyrations': [rg,]}, NATOMS),] * NSAMPLE_PER_SYS

    #### atoms listed in index_map are pinned to the mapped template coordinates; the rest stay free
    repaint_info = {
        "structure": np.stack([structure_template[index_map[i]] if i in index_map else np.zeros(3)
                               for i in range(NATOMS)]).astype(np.float32),
        "mask": np.array([i in index_map for i in range(NATOMS)], dtype=np.bool_),
    }
    repaint_infos += [repaint_info,] * NSAMPLE_PER_SYS

#### fill every list up to NSAMPLES entries by repeating its last element
#### (nothing is added when a list already holds NSAMPLES entries or more)
constituents = constituents + (NSAMPLES - len(constituents)) * [constituents[-1],]
bonds = bonds + (NSAMPLES - len(bonds)) * [bonds[-1],]
input_dicts = input_dicts + (NSAMPLES - len(input_dicts)) * [input_dicts[-1],]
repaint_infos = repaint_infos + (NSAMPLES - len(repaint_infos)) * [repaint_infos[-1],]

#### Structure Sampling

In [None]:
#### stack the per-sample dictionaries into batched arrays, key by key
input_dict = {}
for k in input_dicts[0].keys():
    input_dict[k] = np.stack([d[k] for d in input_dicts])
repaint_dict = {}
for k in repaint_infos[0].keys():
    repaint_dict[k] = np.stack([d[k] for d in repaint_infos])

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))
for k, v in repaint_dict.items():
    print("\tRePaint {} shape: {} dtype: {}".format(k, v.shape, v.dtype))

#### convert every leaf to a JAX array
input_dict = jax.tree_map(jnp.array, input_dict)
repaint_dict = jax.tree_map(jnp.array, repaint_dict)

In [None]:
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key, repaint_dict=repaint_dict)

#### save results 
#### create the output directory first so the inference results are not lost to a FileNotFoundError
os.makedirs('results/FEP_structure_build', exist_ok=True)
with open('results/FEP_structure_build/R_group_modifications.pkl', 'wb') as f: 
    #### `repaint_info` is whatever the input-building loop left behind (the last system only);
    #### it is kept under its original key for compatibility, and the per-sample list that
    #### actually conditioned this batch is stored next to it as 'repaint_infos'.
    pkl.dump(jax.tree_map(np.array, 
                          {'constituents': constituents,
                           'repaint_info': repaint_info,
                           'repaint_infos': repaint_infos,
                           'bonds': bonds,
                           'trajectories': trajectories, 'structures': structures}), f)

#### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
#### (pickle is only safe here because the file was written by this notebook)
with open('results/FEP_structure_build/R_group_modifications.pkl', 'rb') as f: 
    results = pkl.load(f)
    constituents = results['constituents']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

mol_id = 3
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
n_atoms = len(atomic_numbers)

#### drop the padding slots beyond the real atoms
structure = np.array(structures)[mol_id, :n_atoms, :]
#### the radius of gyration is measured from the centroid; the sampled coordinates are not
#### guaranteed to be centred (RePaint pins atoms to the template frame), so centre them
#### for this calculation only and keep `structure` itself untouched for viewing.
rg = np.sqrt(np.sum((structure - np.mean(structure, axis=0, keepdims=True)) ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
#### atom names are "<element symbol><atom index>"
mol.add_TopologyAttr('names', ["{}{}".format(elements[int(n)], i) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

## Lead-Imprinted Binder Design

### Lead Template (PDBID: 5A9U)

In [None]:
from rdkit import Chem 

#### lead ligand used as the shape template
sys_name = '5A9U'
suppl = Chem.SDMolSupplier(f'moledit_dataset/lead_imprinted_binder_design/{sys_name}/5a9u_ligand.sdf')
mol = suppl[0]
#### RDKit yields None (instead of raising) for a record it cannot parse; stop here with a
#### clear message rather than failing later inside the constituents extraction
if mol is None:
    raise ValueError("RDKit could not parse the first molecule of the {} ligand SDF".format(sys_name))

mol

In [None]:
#### constituents and coordinates of the lead template
constituents_dict_template, structure_template = RDMol_to_constituents_and_structure(mol)
n_atoms_template = len(structure_template)
#### radius of gyration of the template around its centroid
rg_template = np.sqrt(np.mean(
    np.linalg.norm(structure_template - np.mean(structure_template, axis=0, keepdims=True), axis=-1)**2
))
atomic_numbers, hydrogen_numbers, hybridizations, bonds = (
    constituents_dict_template[key] for key in ('atomic_numbers', 'hydrogen_numbers', 'hybridizations', 'bonds')
)
print("atomic numbers:  ", ",".join(map(str, atomic_numbers)))
print("hydrogen_numbers:", ",".join(map(str, hydrogen_numbers)))
print("hybridizations:  ", ",".join(map(str, hybridizations)))

### Constituents Sampling

In [None]:
#### Sample atom constituents from datasets (ZINC-300k)
#### The loaded object is iterated with .items() by the selection cell, i.e. it is a mapping.
#### NOTE: pickle can execute arbitrary code on load -- only use dataset files from a trusted source.
with open('moledit_dataset/constituents/constituents_ZINC_300k.pkl', 'rb') as f:
    constituents_data = pkl.load(f)

#### or alternatively, you can sample from ZINC-3m, ZINC-30m, or from constituents model

In [None]:
#### select n_atoms approx n_atoms in templates
#### every entry gets the template rg; only entries whose atom count differs from the
#### template by less than 5 are kept as candidates
constituents_all = []
for smi, c in tqdm(constituents_data.items()):
    c['radius_of_gyrations'] = [rg_template, ]
    if np.abs(len(c['atomic_numbers']) - n_atoms_template) >= 5:
        continue
    constituents_all.append(c)

#### draw NSAMPLES candidates uniformly at random (the same candidate may be drawn repeatedly)
random_idx = np.random.randint(0, len(constituents_all), NSAMPLES)
constituents = [constituents_all[idx] for idx in random_idx]

In [None]:
#### show the first sampled constituent set as an example
print("Example constituents: ")
print("\tatomic numbers: ", constituents[0]['atomic_numbers'])
print("\thydrogen numbers: ", constituents[0]['hydrogen_numbers'])
print("\thybridizaions: ", constituents[0]['hybridizations'])
print("\tradius of gyrations: ", constituents[0]['radius_of_gyrations'])
print("\t**REMARK**: hybridization symbols are same with RDkit")

from inference.utils import preprocess_data

print("Preprocessing inputs")
#### preprocess every sample, then stack the results key by key into batched arrays
input_dicts = []
for c in tqdm(constituents):
    input_dicts.append(preprocess_data(c, NATOMS))
input_dict = {}
for k in input_dicts[0].keys():
    input_dict[k] = np.stack([d[k] for d in input_dicts])

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

In [None]:
#### shape template: coordinates and atom mask are zero-padded from n_atoms_template to NATOMS
#### and then repeated NSAMPLES times along a new leading axis; the coefficient starts at 0.0
shape_template_dict = {
    'template_structure': np.tile(
        np.pad(structure_template, ((0, NATOMS-n_atoms_template), (0, 0)))[None, ...],
        (NSAMPLES, 1, 1)),
    'template_atom_mask': np.tile(
        np.pad(np.ones(n_atoms_template, dtype=np.bool_), (0, NATOMS-n_atoms_template))[None, ...],
        (NSAMPLES, 1)),
    'template_coeff': 0.0,
}

### Structure Sampling

In [None]:
#### batch size is NSAMPLE_PER_DEVICE samples on each of the NDEVICES devices
#### NOTE(review): NSAMPLES and NATOMS are already read by the cells above this one, so those
#### cells rely on values assigned earlier in the notebook -- confirm they agree with these.
NSAMPLE_PER_DEVICE = 8 # 128
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES
NATOMS = 64 
#### sampler selector; the dispatch below handles "DPM_3", "DPM_pp_2S" and "Langevin"
INFERENCE_METHOD = "DPM_3"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Evaluate the noise-band-selected score, blending bonded and bond-free predictions.

    Every network in the global ``moledit_scorenets`` is applied twice: once with
    ``bond_feat`` and once with ``bond_feat`` replaced by zeros. The two results are mixed
    as ``gamma * bonded + (1 - gamma) * bond_free``. The mixed outputs are then weighted by
    an indicator of which interval of the global ``noise_thresholds`` contains ``sigma``
    and summed over the network axis.

    Assumes ``noise_thresholds`` is sorted ascending and ``moledit_scorenets`` holds one
    network per interval (``len(noise_thresholds) + 1``) -- TODO confirm against the cell
    that defines them.

    Args:
        atom_feat, bond_feat, x, atom_mask, sigma, rg: forwarded unchanged to ``net.apply``
            (except ``bond_feat`` in the bond-free pass).
        gamma: mixing weight of the bonded prediction; 1.0 uses the bonded prediction only.

    Returns:
        The last output of the networks, selected by noise band.
    """
    n_thresholds = len(noise_thresholds)
    # indicator per interval: (-inf, t0), [t0, t1), ..., [t_last, +inf)
    band_flags = [sigma < noise_thresholds[0]]
    for band in range(n_thresholds - 1):
        band_flags.append(
            jnp.logical_and(sigma >= noise_thresholds[band], sigma < noise_thresholds[band + 1])
        )
    band_flags.append(sigma >= noise_thresholds[-1])

    def _all_net_outputs(bonds):
        # last output of every score network for the given bond features
        return [net.apply({}, atom_feat, bonds, x, atom_mask, sigma, rg)[-1]
                for net in moledit_scorenets]

    bonded = jnp.array(_all_net_outputs(bond_feat), jnp.float32)
    bond_free = jnp.array(_all_net_outputs(jnp.zeros_like(bond_feat)), jnp.float32)
    blended = gamma * bonded + (1.0 - gamma) * bond_free

    band_weights = jnp.array(band_flags, dtype=jnp.float32)[..., None, None]
    return jnp.sum(band_weights * blended, axis=0)

#### jit(vmap(jit(.))): the per-sample score function is mapped over the leading batch axis
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
#### bind the score function and the sampler-specific number of steps
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    #### without this branch a typo would silently keep an `inference_fn` left over from an
    #### earlier section of the notebook (bound to a different score function)
    raise ValueError(
        "Unknown INFERENCE_METHOD {!r}; expected 'DPM_3', 'DPM_pp_2S' or 'Langevin'".format(INFERENCE_METHOD)
    )

In [None]:
shape_template_coeff = 32.0 ### This coefficient controls the additional shape-gradient term in SDE

input_dict = jax.tree_map(lambda x:jnp.array(x), input_dict)
shape_template_dict['template_coeff'] = shape_template_coeff 
shape_template_dict = jax.tree_map(lambda x:jnp.array(x), shape_template_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
#### n_steps is deliberately NOT passed here: `inference_fn` already carries the step count
#### chosen for the selected sampler, and a keyword given at call time would override it
#### (e.g. forcing the Langevin sampler down from 1000 to 20 steps).
structures, trajectories, rng_key = inference_fn(input_dict, rng_key, shape_dict=shape_template_dict)

In [None]:
#### save results 
#### the output path depends on sys_name, so make sure the per-system directory exists
#### before writing (otherwise open() raises FileNotFoundError and the samples are lost)
os.makedirs(f'results/lead_imprinted_binder_design/{sys_name}', exist_ok=True)
with open(f'results/lead_imprinted_binder_design/{sys_name}/results_coef_{shape_template_coeff}.pkl', 'wb') as f: 
    #### every leaf is converted to a NumPy array before pickling
    pkl.dump(jax.tree_map(np.array, 
                          {'constituents': constituents,
                           'template': {'template_structure': structure_template, 'template_constituents': constituents_dict_template},
                           'trajectories': trajectories, 'structures': structures}), f)

### Graph Assembly

In [None]:
from graph_assembler.graph_assembler import assemble_mol_graph, check_bonds

#### try to assemble a molecular graph for every sampled structure; a failed assembly
#### is recorded with an empty SMILES string so both lists stay aligned with `constituents`
structures = np.array(structures)
success_or_not, smileses = [], []
for i, (constituent, structure) in tqdm(enumerate(zip(constituents, structures))):
    success, Xponge_mol, smiles = assemble_mol_graph(
        constituent['atomic_numbers'],
        constituent['hydrogen_numbers'],
        structure,
    )
    success_or_not.append(success)
    smileses.append(smiles if success else "")

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open(f'results/lead_imprinted_binder_design/{sys_name}/results_coef_{shape_template_coeff}.pkl', 'rb') as f: 
    results = pkl.load(f)
    constituents, trajectories, structures = (
        results[key] for key in ('constituents', 'trajectories', 'structures')
    )

#### atomic number -> element symbol
elements = {
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
    15: 'P',
    16: 'S',
    17: 'Cl',
    35: 'Br',
    53: 'I',
}

mol_id = 0
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
n_atoms = len(atomic_numbers)

#### keep the real atoms only and move the centroid to the origin
structure = np.array(structures)[mol_id, :n_atoms, :]
structure = structure - structure.mean(axis=0, keepdims=True)
#### radius of gyration of the centred sample, printed next to the value stored with the constituents
rg = np.sqrt((structure ** 2).sum() / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f}/{:.2f} ang".format(n_atoms, rg, constituents[mol_id]['radius_of_gyrations'][0]))
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=n_atoms)
#### atom names are "<element symbol><hydrogen count of that atom>"
mol.add_TopologyAttr(
    'names',
    ["{}{}".format(elements[int(n)], hydrogen_numbers[i]) for i, n in enumerate(atomic_numbers)],
)
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view