In [None]:
import os
# Turn off XLA's default up-front GPU memory preallocation. JAX reads this when
# its backend is initialised, so it must be set before `jax` is used (it is
# imported in a later cell); keep this line first.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys 
# Make project-local packages in the working directory importable by the cells
# below (e.g. `cybertron`, `train`, `inference`, `config`, `data`).
sys.path.append(".")
import math
import pickle as pkl 
import numpy as np
from tqdm import tqdm

# Show the module search path (confirms "." was appended).
print("We are now running Python in: ", sys.path)

## Preparation (nets, constants, params and functional utils)

In [None]:
import jax 
import jax.numpy as jnp

from functools import partial
from cybertron.common.config_load import load_config

#### load nets
from train.train import MolEditScoreNet
from cybertron.model.molct_plus import AdaLNMolCT_Plus
from cybertron.readout import AdaLNGFNReadout

from train.utils import set_dropout_rate_config
from jax.sharding import PositionalSharding

def _sharding(input, shards):

    n_device = shards.shape[0]
    if isinstance(input, (np.ndarray, jax.Array)):
        _shape = [n_device, ] + [1 for _ in range(input.ndim - 1)]
        return jax.device_put(input, shards.reshape(_shape))
    elif input is None:
        return jax.device_put(input, shards)
    else:
        raise TypeError(f"Invalid input: {input}")

from inference.inference import DPM_3_inference, Langevin_inference, DPM_pp_2S_inference

In [None]:
#### device configuration
SHARDING = True  #### True: spread the work over every device reported by jax.devices()
NATOMS = 64
NDEVICES = len(jax.devices()) if SHARDING else 1
if SHARDING:
    print("{} DEVICES detected: {}".format(NDEVICES, jax.devices()))

def split_rngs(rng_key, shape):
    """Split ``rng_key`` into an array of independent keys laid out as ``shape``.

    Args:
        rng_key: a JAX PRNG key.
        shape: leading shape of the returned key array. Accepts a tuple or a list
            of ints, or a single int.

    Returns:
        ``(keys, next_key)``. ``keys`` has shape ``tuple(shape) + (key_size,)``.
        ``next_key`` is one extra fresh key for subsequent use.
    """
    # Normalise to a tuple so `shape + (-1,)` below works for lists and ints too.
    if np.isscalar(shape):
        shape = (int(shape),)
    else:
        shape = tuple(int(s) for s in shape)
    # np.prod(()) is the float 1.0; cast so jax.random.split always receives an int.
    size = int(np.prod(shape))
    rng_keys = jax.random.split(rng_key, size + 1)
    return rng_keys[:-1].reshape(shape + (-1,)), rng_keys[-1]

#### random seeds: numpy drives property sampling / masking, JAX drives model sampling
np.random.seed(7777)
rng_key = jax.random.PRNGKey(8888)  #### set your random seed here

In [None]:
##### initialize models
encoder_config = load_config("config/molct_plus.yaml")
gfn_config = load_config("config/gfn.yaml")
gfn_config.model.num_atoms = 128
gfn_n_interactions = [4, 4, 3]

modules = []
for n_inter in gfn_n_interactions:
    gfn_config.settings.n_interactions = n_inter
    modules.append({
        "encoder": {"module": AdaLNMolCT_Plus, 
                    "args": {"config": encoder_config}},
        "gfn": {"module": AdaLNGFNReadout, 
                "args": {"config": gfn_config}}
    })

##### load params
load_ckpt_paths = ['./params/QMugs/structure_model/moledit_params_track1.pkl', 
                   './params/QMugs/structure_model/moledit_params_track2.pkl',
                   './params/QMugs/structure_model/moledit_params_track3.pkl'] 
noise_thresholds = [0.35, 1.95]

params = []
for path in load_ckpt_paths:
    with open(path, 'rb') as f: 
        params.append(pkl.load(f))
    
if SHARDING:
    ##### replicate params
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())

for param, module in zip(params, modules):
    for k, v in module.items():
        module[k]['args']['config'] = \
            set_dropout_rate_config(module[k]['args']['config'], 0.0)
        module[k]["module"] = v["module"](**v["args"])
        partial_params = {"params": param["params"]['score_net'].pop(k)}
        module[k]["callable_fn"] = partial(module[k]["module"].apply, partial_params)

moledit_scorenets = [MolEditScoreNet(
        encoder=modules[k]['encoder']['callable_fn'],
        gfn=modules[k]['gfn']['callable_fn'],
        with_cond = True,
    ) for k in range(len(load_ckpt_paths))]

In [None]:
#### Sample atom constituents from constituent model
from config.transformer_config import transformer_config
from transformer.prefix_model import Transformer, TransformerConfig

##### constituent vocabulary (a trusted local pickle)
with open('params/QMugs/constituents_model/constituents_vocab.pkl', 'rb') as vocab_file:
    constituent_vocab_list = pkl.load(vocab_file)

#### sequence layout: NCONSTITUENTS count tokens followed by NRG_TOKENS rg tokens
NCONSTITUENTS = len(constituent_vocab_list)  # 30
NRG_TOKENS = 3  # seq_len = 30 + 3
SEQ_LEN = NCONSTITUENTS + NRG_TOKENS
NRG_VOCABS = 11

#### inference-mode transformer; vocabulary = 64 count ids + NRG_VOCABS rg ids + 1
transformer_config.deterministic = True
transformer_config.dtype = jnp.float32
transformer = Transformer(config=TransformerConfig(**{
    **transformer_config,
    "vocab_size": 64 + NRG_VOCABS + 1,
    "output_vocab_size": 64 + NRG_VOCABS + 1,
}))

##### load params (trusted local pickle) and turn every leaf into a jax array
with open("params/QMugs/constituents_model/moledit_params.pkl", "rb") as params_file:
    params = pkl.load(params_file)
params = jax.tree_util.tree_map(jnp.array, params)

if SHARDING:
    ##### replicate params on every device
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())

def top_p_sampling(logits, rng_key, p=0.9):
    """Nucleus (top-p) sampling of one token index from 1-D ``logits``.

    Tokens are ranked by logit. The low-probability tail, whose cumulative
    probability stays below ``1 - p``, gets a -1e5 logit penalty. One index is
    then drawn from the remaining tokens with ``jax.random.categorical``.

    Returns:
        ``(token_index, new_rng_key)``. The index refers to the ORIGINAL
        ``logits`` ordering. Use ``new_rng_key`` for the next draw.
    """
    order = jnp.argsort(logits)  # ascending: least likely token first
    ranked_logits = logits[order]
    tail_mass = jnp.cumsum(jax.nn.softmax(ranked_logits))
    in_tail = tail_mass < (1 - p)

    next_key, draw_key = jax.random.split(rng_key)
    penalised = ranked_logits + in_tail.astype(jnp.float32) * (-1e5)
    choice = jax.random.categorical(draw_key, penalised)

    return order[choice], next_key
        
##### prepare functions, jit & vmap
##### jit the transformer forward pass, and vmap the single-token sampler over the
##### two leading axes of logits / rng_keys (samples x sequence positions)
jitted_logits_fn = jax.jit(transformer.apply)
top_p_sampling_fn = jax.vmap(
    jax.vmap(
        jax.jit(partial(top_p_sampling, p=0.9))
    )
)

## Property-guided Sampling

In [None]:
NSAMPLE_PER_DEVICE = 8  # 128 for a larger run
NSAMPLES = NDEVICES * NSAMPLE_PER_DEVICE  # total number of molecules across all devices

### Sample Properties

In [None]:
#### conditioning properties; this order is the order of the transformer prefix rows
required_properties = [
    'mw',  # molecular weight
    'rotatable_bonds',
    'rings',
    'hbond_acceptors',
    'hbond_donors',
    'LogP',
    'TPSA',
    'DFT_DIPOLE_TOT',
    'DFT_HOMO_LUMO_GAP',
]

In [None]:
#### Sample properties from datasets
#### draw (with replacement) property targets from random QMugs molecules (trusted local pickle)
with open("./moledit_dataset/property/QMugs.pkl", "rb") as property_file:
    property_dict = pkl.load(property_file)

properties = []
for mol_key in np.random.choice(list(property_dict.keys()), NSAMPLE_PER_DEVICE * NDEVICES):
    properties.append({p: property_dict[mol_key][p] for p in required_properties})

### Constituents Sampling

In [None]:
from data.constants import property_info

def process_a_property(data_property, property_mask_type="no_mask"):
    """Encode one molecule's property values as a transformer prefix.

    Each property in ``required_properties`` is standardised with the mean/std
    stored in ``property_info`` and expanded on that property's RBF centres.

    Args:
        data_property: mapping from property name to raw value.
        property_mask_type: ``'all_mask'`` masks every property (zero rows).
            ``'no_mask'`` masks none. Any other value masks each property
            independently with probability 0.5.

    Returns:
        dict with ``'prefix'`` (float32, one row per property) and
        ``'prefix_mask'`` (bool, one entry per property; True = property given).
    """
    n_props = len(required_properties)
    # assumes every property has as many RBF centres as the first one
    n_rbf = len(property_info[required_properties[0]]['rbf_centers'])

    if property_mask_type == 'all_mask':
        return {
            'prefix': np.zeros((n_props, n_rbf), dtype=np.float32),
            'prefix_mask': np.zeros(n_props, dtype=np.bool_),
        }

    drop_prob = 0.5 if property_mask_type != 'no_mask' else 0.0
    rows, given = [], []
    for name in required_properties:
        # One np.random.rand() draw per property even when drop_prob == 0.0, so the global
        # numpy RNG advances the same way in 'no_mask' and random-mask modes.
        if np.random.rand() < drop_prob:
            rows.append(np.zeros(n_rbf, dtype=np.float32))
            given.append(False)
            continue
        info = property_info[name]
        z = (data_property[name] - info['mean']) / info['std']
        sigma = info['rbf_sigma']
        # NOTE(review): a normalised Gaussian would use 1/sqrt(2*pi*sigma**2); the factor
        # 1/sqrt(2*pi*sigma) is kept as-is, presumably to match the training preprocessing.
        rows.append(1.0 / np.sqrt(2 * np.pi * sigma) * np.exp(-0.5 * ((z - info['rbf_centers']) / sigma) ** 2))
        given.append(True)

    return {
        'prefix': np.array(rows, dtype=np.float32),
        'prefix_mask': np.array(given, dtype=np.bool_),
    }

In [None]:
#### encode every sampled property set, then stack into batched arrays
prefix_dicts = list(map(process_a_property, properties))
prefix_dict = {
    key: np.array([d[key] for d in prefix_dicts])
    for key in ('prefix', 'prefix_mask')
}

for key, arr in prefix_dict.items():
    print(key, arr.shape, arr.dtype)

In [None]:
#### Autoregressive constituent sampling. Token layout, as used by the masks below:
#### ids 1..64 encode constituent counts 0..63, ids 65..74 are the rg tokens,
#### and id 64 + NRG_VOCABS is the rg "unk" placeholder fed as input.
input_dict = {
    "inputs": jnp.ones((NSAMPLES, SEQ_LEN), dtype=jnp.int32), 
    "prefix": jnp.array(prefix_dict["prefix"]), 
    "prefix_mask": jnp.array(prefix_dict["prefix_mask"]),
    "generation_result": jnp.ones((NSAMPLES, SEQ_LEN), dtype=jnp.int32)
}
input_dict["inputs"] = input_dict["inputs"].at[:, NCONSTITUENTS:].set(64 + NRG_VOCABS) #### unk token for rg

#### one PRNG key per (sample, position); the extra last key is kept in rng_key for later cells
rng_keys = jax.random.split(rng_key, NSAMPLES*SEQ_LEN + 1)
rng_keys, rng_key = rng_keys[:NSAMPLES*SEQ_LEN].reshape(NSAMPLES, SEQ_LEN, -1), rng_keys[-1]

if SHARDING:
    #### shard inputs along the sample axis
    ds_sharding = partial(_sharding, shards=global_sharding)
    # jax.tree_map is deprecated; jax.tree_util.tree_map is the same function
    input_dict = jax.tree_util.tree_map(ds_sharding, input_dict)
    rng_keys = ds_sharding(rng_keys)

inv_temperature = 1.25  #### logits are multiplied by this, i.e. sampling temperature 0.8
for step in tqdm(range(SEQ_LEN)):
    logits = jitted_logits_fn(params, 
                              input_dict['inputs'],
                              input_dict['generation_result'], 
                              input_dict['prefix'], 
                              input_dict['prefix_mask'])
    #### restrict the vocabulary: rg tokens for the trailing positions, count tokens otherwise
    if step >= NCONSTITUENTS:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
    else:
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., 1:-NRG_VOCABS].set(1)
    logits += (-1e5) * (1.0 - valid_logits_mask)
    sampled_token, rng_keys = top_p_sampling_fn(logits * inv_temperature, rng_keys)
    #### commit only the current position; the next step conditions on it
    input_dict['generation_result'] = input_dict['generation_result'].at[..., step].set(sampled_token[..., step])

#### shift ids down by one: constituent positions now hold atom counts and the
#### rg positions hold 64 + rg value (the "- 64" below recovers the value)
generation_result = np.array(input_dict['generation_result']) - 1
constituents = []
for seq in tqdm(generation_result):
    atomic_numbers, hydrogen_numbers, hybridizations = [], [], []
    n_atoms = 0  # accumulated below but not used afterwards
    #### decode constituents
    #### each vocab entry is a '<atomic_number>_<hydrogen_number>_<hybridization>' string;
    #### the sampled count `num` repeats that atom type `num` times
    for token, num in zip(constituent_vocab_list, seq[:NCONSTITUENTS]):
        atomic_number, hydrogen_number, hybridization = tuple([int(x) for x in token.split('_')])
        atomic_numbers += [atomic_number,] * num 
        hydrogen_numbers += [hydrogen_number,] * num 
        hybridizations += [hybridization,] * num
        n_atoms += num 
        
    #### decode rg
    #### the first rg value is used as an exponent of e (np.exp) and the remaining values as
    #### the digits of "<d1>.<d2...>"; NOTE(review): confirm this matches the rg encoding
    #### used to build the training sequences (e.g. base e vs. base 10)
    rg_seq = seq[-NRG_TOKENS:] - 64
    # print(rg_seq)
    rg = np.exp(rg_seq[0]) * float("{}.{}".format(rg_seq[1], "".join([str(x) for x in rg_seq[2:]])))
    constituents.append(
        {"atomic_numbers": np.array(atomic_numbers, dtype=np.uint8), 
         "hydrogen_numbers": np.array(hydrogen_numbers, dtype=np.uint8),
         "hybridizations": np.array(hybridizations, dtype=np.uint8), 
         "radius_of_gyrations": np.array([rg], dtype=np.float32)}
    )

In [None]:
### attach properties: copy each sample's target properties into its constituent dict (in place)
for constituent, target in zip(constituents, properties):
    constituent.update(target)

In [None]:
from rdkit import Chem 
import matplotlib.pyplot as plt

PeriodicTable = Chem.GetPeriodicTable()

def calculate_mw(c):
    """Molecular weight implied by a constituent dict.

    Sums the RDKit atomic weight of every atom in ``c['atomic_numbers']`` and
    adds one hydrogen weight per hydrogen counted in ``c['hydrogen_numbers']``.
    """
    atom_weights = [
        PeriodicTable.GetAtomicWeight(PeriodicTable.GetElementSymbol(int(z)))
        for z in c['atomic_numbers']
    ]
    hydrogen_weight = np.sum(c['hydrogen_numbers']) * PeriodicTable.GetAtomicWeight('H')
    return np.sum(atom_weights) + hydrogen_weight

#### target MW (from the dataset) vs. MW implied by the sampled constituents
mw_x = [p['mw'] for p in properties]
mw_y = [calculate_mw(c) for c in constituents]

fig, ax = plt.subplots()
ax.scatter(mw_x, mw_y)
#### y = x reference line: points on it match the target exactly
lims = [min(min(mw_x), min(mw_y)), max(max(mw_x), max(mw_y))]
ax.plot(lims, lims, "k--", lw=1)
ax.set(xlabel="target MW", ylabel="MW of sampled constituents",
       title="Property guidance: molecular weight")
plt.show()

### Structure Sampling

In [None]:
NATOMS = 128                #### atom count passed to preprocessing (structures are sliced to the true count later)
INFERENCE_METHOD = "DPM_3"  #### "DPM_3" or "Langevin"

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, prop):
    """Score of the structure model, routed by the noise level ``sigma``.

    ``noise_thresholds`` split the sigma axis into ``len(noise_thresholds) + 1``
    intervals, and interval ``i`` is served by ``moledit_scorenets[i]``. Every
    net is evaluated; the outputs are combined with 0/1 interval weights, so no
    Python branching on the traced ``sigma`` is needed.
    """
    edges = noise_thresholds
    in_interval = [sigma < edges[0]]
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_interval.append(jnp.logical_and(sigma >= lo, sigma < hi))
    in_interval.append(sigma >= edges[-1])

    scores = [
        net.apply({}, atom_feat, bond_feat, x, atom_mask, sigma, rg, prop)[-1]
        for net in moledit_scorenets
    ]
    weights = jnp.array(in_interval, dtype=jnp.float32)[..., None, None]
    return jnp.sum(weights * jnp.array(scores, jnp.float32), axis=0)

#### score function batched over molecules with vmap, compiled with jit
score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    #### fail here instead of with a NameError on `inference_fn` in a later cell
    raise ValueError(
        f"Unknown INFERENCE_METHOD {INFERENCE_METHOD!r}; expected 'DPM_3' or 'Langevin'")

In [None]:
#### show the decoded constituents of the first sample
print("Example constituents: ")
print("\tatomic numbers: ", constituents[0]['atomic_numbers'])
print("\thydrogen numbers: ", constituents[0]['hydrogen_numbers'])
print("\thybridizations: ", constituents[0]['hybridizations'])
print("\t**REMARK**: hybridization symbols are the same as RDKit's")

from inference.utils import preprocess_data_with_property

print("Preprocessing inputs")
#### the structure model is conditioned on a subset of the sampled properties
input_dicts = []
for c in tqdm(constituents):
    input_dicts.append(preprocess_data_with_property(
        c, NATOMS,
        properties=['rotatable_bonds', 'rings',
                    'LogP', 'TPSA', 'DFT_DIPOLE_TOT', 'DFT_HOMO_LUMO_GAP']))

#### stack the per-molecule arrays into batched arrays with the same keys
input_dict = {key: np.stack([d[key] for d in input_dicts]) for key in input_dicts[0]}

print("input shape & dtypes: ")
for key, arr in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(key, arr.shape, arr.dtype))

In [None]:
#### jax.tree_map is deprecated; jax.tree_util.tree_map is the same function
input_dict = jax.tree_util.tree_map(jnp.array, input_dict)
#### JAX compiles a jitted function when you call it first time.
#### so it will be slow when you run this block first time.
structures, trajectories, rng_key = inference_fn(input_dict, rng_key)

#### save results (create the output directory on a fresh checkout)
os.makedirs('results/property_guidance', exist_ok=True)
with open('results/property_guidance/property_guided_sampling.pkl', 'wb') as f: 
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'constituents': constituents, 'trajectories': trajectories, 'structures': structures}), f)

### Graph Assembly

In [None]:
import Xponge
from graph_assembler.graph_assembler import assemble_mol_graph

#### assemble a molecular graph (and SMILES) for every sampled structure;
#### failed assemblies are recorded as "" in `smileses`
success_or_not, smileses = [], []
per_molecule = zip([c['atomic_numbers'] for c in constituents],
                   [c['hydrogen_numbers'] for c in constituents],
                   structures)
for i, (atomic_numbers, hydrogen_numbers, structure) in tqdm(enumerate(per_molecule)):
    success, Xponge_mol, smiles = assemble_mol_graph(atomic_numbers, hydrogen_numbers, structure)
    success_or_not.append(success)
    smileses.append(smiles if success else "")

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open('results/property_guidance/property_guided_sampling.pkl', 'rb') as f: 
    results = pkl.load(f)
    constituents = results['constituents']
    trajectories = results['trajectories']
    structures = results['structures']

elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

#### molecule to view; clamped so the default also works when fewer than 11 molecules
#### were sampled (e.g. 8 samples on a single device)
mol_id = min(10, len(constituents) - 1)
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
n_atoms = len(atomic_numbers)
#### keep only the first n_atoms rows (presumably the rest is padding up to NATOMS)
trajectory = np.array(trajectories)[:, mol_id, :n_atoms, :]
structure = np.array(structures)[mol_id, :n_atoms, :]
#### radius of gyration about the centroid (identical to the plain formula when the
#### structure is already centred at the origin)
rg = np.sqrt(np.sum((structure - structure.mean(axis=0)) ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f}/{:.2f} ang".format(n_atoms, constituents[mol_id]['radius_of_gyrations'][0], rg))
print("SMILES: {}".format(smileses[mol_id]))
print("WARNING: bonds provided by NGLViewer may be problematic")

mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}H{}".format(elements[n], hydrogen_numbers[i]) for i, n in enumerate(atomic_numbers)])
# mol.add_TopologyAttr('names', ["{}".format(elements[n]) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

### Property Analysis

In [None]:
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Lipinski import NumRotatableBonds
from rdkit.Chem.rdMolDescriptors import CalcNumRings
from rdkit.Chem.Lipinski import NumHAcceptors 
from rdkit.Chem.Lipinski import NumHDonors
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Descriptors import TPSA

#### target TPSA (x) vs. TPSA that RDKit computes for each assembled molecule (y)
prop_x = []
prop_y = []
for c, smi in zip(constituents, smileses):
    if smi == "":
        continue  # graph assembly failed for this sample
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        continue  # SMILES that RDKit cannot parse
    try:
        tpsa = TPSA(mol)
    except Exception:
        continue  # best effort: skip molecules RDKit cannot featurise
    #### append x and y together so the two lists always stay paired
    prop_x.append(c['TPSA'])
    prop_y.append(tpsa)

fig, ax = plt.subplots()
ax.scatter(prop_x, prop_y)
ax.set(xlabel="target TPSA", ylabel="TPSA of generated molecule (RDKit)",
       title="Property guidance: TPSA")
plt.show()