In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys 
sys.path.append(".")
import pickle as pkl 
import numpy as np
from tqdm import tqdm

## Inference

### Preparation (nets, constants and params)

In [2]:
import jax 
import jax.numpy as jnp
from functools import partial

#### nets
from cybertron.readout.naive_gfn import NaiveGraphFieldNetwork
from jax.sharding import PositionalSharding

In [3]:
#### setup environments (single-device or multiple-devices)
#### SHARDING=True uses every visible JAX device; False keeps everything on one device.
SHARDING = True #### you can use multiple devices
NDEVICES = len(jax.devices()) if SHARDING else 1
if SHARDING:
    print("{} DEVICES detected: {}".format(NDEVICES, jax.devices()))

#### random seeds: the JAX key feeds the jax.random samplers below,
#### the NumPy seed fixes the global np.random stream used later in the notebook
np.random.seed(7777)
rng_key = jax.random.PRNGKey(8888) #### set your random seed here

8 DEVICES detected: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]


In [4]:
#### constants
#### constituent vocabulary; after sorting, the empty string '' is index 0, which marks padding atoms
constituent_types = ['', 'O3H2', 'N2H2', 'CarH1', 'N2H1', 'NamH0', 'Oco2H0', 'N3H3', 'N4H3', 'C1H0', 'N2H0', 'Npl3H0', 'NarH1', 'O3H1', 'C2H0', 'O3H0', 'N4H1', 'CcatH0', 'N4H2', 'CarH0', 'N3H2', 'NamH2', 'C3H1', 'Npl3H2', 'NamH1', 'C2H2', 'N3H1', 'C3H3', 'C1H1', 'C3H2', 'N1H0', 'N3H0', 'FH0', 'C2H1', 'C3H0', 'C3H4', 'NarH0', 'O2H0', 'Npl3H3', 'Npl3H1', 'O2H1']
constituent_types.sort()
max_num_atoms = 9
num_experts = 2

#### setup & load trained models
#### expert i is built with cutoffs[i] and loads param_paths[i];
#### noise_thresholds (ascending) split sigma into num_experts intervals, expert 0 = smallest sigma
cutoffs = [10.0, 15.0]
noise_thresholds = [0.5]
arg_dicts = {
    "num_atom_types": len(constituent_types), 
    "dim_atom_feature": 128, 
    "dim_edge_feature": 128, 
    "dim_atom_filter": 128, 
    "num_rbf_basis": 128, 
    "n_interactions": 6, 
}
param_paths = ["./params/naive_gfn_params/naive_gfn_track_1_jax.pkl", "./params/naive_gfn_params/naive_gfn_track_2_jax.pkl"]

#### fail fast on an inconsistent mixture-of-experts configuration: zip() below would
#### otherwise silently drop experts and the sigma routing would no longer line up
assert len(cutoffs) == len(param_paths) == len(noise_thresholds) + 1 == num_experts, \
    "expected {} experts: got {} cutoffs, {} param files, {} noise thresholds".format(
        num_experts, len(cutoffs), len(param_paths), len(noise_thresholds))
assert list(noise_thresholds) == sorted(noise_thresholds), "noise_thresholds must be ascending"

nets = [NaiveGraphFieldNetwork(**arg_dicts, cutoff=c) for c in cutoffs]

params = []
for path in param_paths:
    #### NOTE: pickle.load can execute arbitrary code; only load parameter files you trust
    with open(path, 'rb') as f: 
        params.append(pkl.load(f))
if SHARDING:
    ##### replicate params: every device needs the full weights
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())


score_fns = [partial(net.apply, p) for net, p in zip(nets, params)]

### Unconditional Generation

In [5]:
NSAMPLE_PER_DEVICE = 128
NSAMPLES = NSAMPLE_PER_DEVICE * NDEVICES
NATOMS = max_num_atoms

#### jit and vmap functions (mixture of experts)
def score_forward_fn(x, atom_type, sigma):
    """Evaluate the mixture-of-experts score network at noise level ``sigma``.

    Expert ``i`` (``score_fns[i]``) handles the i-th noise interval defined by the
    ascending ``noise_thresholds`` t0 < t1 < ...:
    ``sigma < t0``, ``t0 <= sigma < t1``, ..., ``sigma >= t_last``.

    Args:
        x: coordinates of one molecule (presumably (NATOMS, 3); callers vmap over the batch).
        atom_type: constituent indices of the same molecule.
        sigma: scalar noise level.

    Returns:
        The float32 output of the expert whose interval contains ``sigma``
        (zeros if no interval matches, e.g. sigma is NaN).

    Raises:
        ValueError: if the number of experts differs from the number of noise intervals.
    """
    cond_list = [sigma < noise_thresholds[0]] + \
                [jnp.logical_and(sigma >= noise_thresholds[i], sigma < noise_thresholds[i + 1])
                 for i in range(len(noise_thresholds) - 1)] + \
                [sigma >= noise_thresholds[-1]]
    value_list = [jnp.asarray(fn(x, atom_type), dtype=jnp.float32) for fn in score_fns]

    # Select rather than a 0/1-weighted sum: 0 * inf / 0 * nan is nan, so a non-finite
    # output from an inactive expert would otherwise corrupt the result.
    # jnp.select also rejects a mismatch between len(cond_list) and len(value_list).
    return jnp.select(cond_list, value_list)
    
#### Langevin dynamics iteration
def Langevin_one_step_fn(x, atom_type, rng_key, sigma, alpha):
    """Apply one Langevin update at noise level ``sigma`` with step size ``alpha``.

    Returns ``(x_new, next_key)`` where
    ``x_new = x + sqrt(2 * alpha) * z - alpha * score_forward_fn(x, atom_type, sigma) / sigma``
    and ``z`` is fresh standard-normal noise with the shape of ``x``.
    """
    next_key, noise_key = jax.random.split(rng_key)
    noise = jax.random.normal(noise_key, shape=x.shape, dtype=jnp.float32)
    drift = alpha * score_forward_fn(x, atom_type, sigma) / sigma
    return x + jnp.sqrt(2 * alpha) * noise - drift, next_key

#### batch over molecules (x, atom_type, rng_key); sigma and alpha are shared scalars
_batched_one_step = jax.vmap(jax.jit(Langevin_one_step_fn), in_axes=(0, 0, 0, None, None))
Langevin_one_step_fn_jvj = jax.jit(_batched_one_step)

#### Constituent sampling $s \sim p(s)$

##### You can sample atom constituents from a dataset (here, QM9)

In [6]:
#### draw NSAMPLES constituent rows (uniformly, with replacement) from the QM9 constituent file
qm9_constituents = np.load("./dataset/qm9_constituents.npy")
sample_rows = np.random.choice(qm9_constituents.shape[0], NSAMPLES)
constituents = qm9_constituents[sample_rows]

print("Example constituents: ")
example_names = [constituent_types[i] for i in constituents[0] if i > 0]  # index 0 = padding
print("\t{}".format(" ".join(example_names)))

Example constituents: 
	C3H3 C1H0 C1H0 C3H2 C3H1 C3H2 C3H1 O3H0 C3H1


##### You can sample atom constituents using an autoregressive model

In [7]:
from snail.snail import SNAIL

# model: autoregressive model over atom constituents
constituent_model = SNAIL(len(constituent_types), n_res_layers=5, n_attn_layers=12)
# NOTE: pickle.load can execute arbitrary code; only load parameter files you trust
with open("./params/snail_params/snail_jax.pkl", 'rb') as f:
    constituent_model_params = pkl.load(f)
sample_fn = jax.jit(jax.vmap(jax.jit(partial(constituent_model.apply, 
                                             constituent_model_params))))

# sampling: fill one atom position per iteration.
# x is the model input (constituent index rescaled to [-1, 1]); out holds the integer indices.
x = jnp.zeros((NSAMPLES, max_num_atoms), dtype=jnp.float32)
out = jnp.zeros((NSAMPLES, max_num_atoms), dtype=jnp.int32)
if SHARDING:
    # split the sample (leading) axis across devices so each device handles
    # NSAMPLE_PER_DEVICE molecules; replicating would make every device run the full batch
    batch_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    x = jax.device_put(x, batch_sharding)
    out = jax.device_put(out, batch_sharding)

n_types = len(constituent_types)
for atom in tqdm(range(max_num_atoms)):
    logits = sample_fn(x)
    sample_key, rng_key = jax.random.split(rng_key)
    sampled_c = jax.random.categorical(sample_key, logits[:, atom, :], axis=-1)

    out = out.at[:, atom].set(sampled_c)
    x = x.at[:, atom].set(sampled_c.astype(jnp.float32) / (n_types - 1) * 2 - 1)
    
constituents = np.array(out)
print("Example constituents: ")
print("\t{}".format(" ".join([constituent_types[i] for i in constituents[0] if i > 0])))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.10it/s]

Example constituents: 
	N3H0 C3H2 C3H0 C2H1 C3H2 C3H2 C3H2 O3H0 O2H0





In [8]:
#### ...or you can design the constituents yourself!
#### Remember to set the constituents of padding atoms to 0 (the index of the empty type '')

#### Structure sampling $x \sim p(x \mid s)$

In [9]:
n_steps, n_eq_steps = 1000, 10 
sigma_min, sigma_max = 0.01, 5.0
#### geometric noise schedule from sigma_min up to sigma_max (sampling walks it downwards)
noise_scales = \
    np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), n_steps))

def Langevin_inference(x_t, atom_type, rng_keys, save_traj=False, epsilon=2e-4):
    """Annealed Langevin sampling of structures conditioned on constituents.

    Runs ``n_eq_steps`` updates of ``Langevin_one_step_fn_jvj`` at every noise level of
    ``noise_scales``, from the largest to the smallest. It then applies a final,
    noise-free denoising step ``x - sigma_min * score_forward_fn(x, atom_type, sigma_min)``.

    Args:
        x_t: initial coordinates, batched over molecules on the leading axis.
        atom_type: constituent indices, batched like ``x_t``.
        rng_keys: one PRNG key per molecule.
        save_traj: if True, keep every intermediate ``x_t``. Each saved step retains a full
            batch of coordinates (n_steps * n_eq_steps + 1 arrays), so memory use is large.
        epsilon: base step size; the step at level sigma is
            ``epsilon * sigma**2 / sigma_min**2``.

    Returns:
        ``(final coordinates, saved coordinates (empty list unless save_traj), updated keys)``.
    """
    trajectory = []
    for sigma_t in tqdm(noise_scales[::-1]):
        # step size scales with sigma^2, so the smallest level uses exactly epsilon
        alpha = epsilon * sigma_t * sigma_t / (sigma_min * sigma_min)
        for _ in range(n_eq_steps):
            x_t, rng_keys = Langevin_one_step_fn_jvj(x_t, atom_type, 
                                                     rng_keys, sigma_t, alpha)
            if save_traj:
                trajectory.append(x_t)

    # final denoising step at the smallest noise level (no noise added)
    dx = jax.vmap(score_forward_fn, in_axes=(0, 0, None))(x_t, atom_type, sigma_min)
    x_t = x_t - sigma_min * dx
    if save_traj:
        trajectory.append(x_t)

    return x_t, trajectory, rng_keys

In [10]:
### prepare rng keys for sampling: one key per molecule plus one carried forward
split_rng_keys = jax.random.split(rng_key, NSAMPLES + 1)
rng_keys = split_rng_keys[:NSAMPLES]
rng_key = split_rng_keys[-1]
rng_key, normal_key = jax.random.split(rng_key)
### initial coordinates ~ N(0, 1); NATOMS (not a literal 9) keeps the shape tied to max_num_atoms
x_t = jax.random.normal(normal_key, shape=(NSAMPLES, NATOMS, 3), dtype=jnp.float32)

if SHARDING: 
    ### shard the sample (leading) axis across devices so each device integrates only its
    ### NSAMPLE_PER_DEVICE molecules; .replicate() would make every device run the whole batch
    device_grid = PositionalSharding(jax.devices())
    constituents = jax.device_put(constituents, device_grid.reshape(NDEVICES, 1))
    rng_keys = jax.device_put(rng_keys, device_grid.reshape(NDEVICES, 1))
    x_t = jax.device_put(x_t, device_grid.reshape(NDEVICES, 1, 1))

In [11]:
structures, trajectories, rng_keys = Langevin_inference(x_t, constituents, rng_keys)
#### jax.tree_util.tree_map works on all JAX versions (the jax.tree_map alias is deprecated/removed)
structures, trajectories = jax.tree_util.tree_map(np.array, (structures, trajectories))

#### save results (create the output directory if this is a fresh checkout)
os.makedirs('results', exist_ok=True)
with open('results/results_0.pkl', 'wb') as f: 
    pkl.dump(jax.tree_util.tree_map(np.array, 
                                    {'constituents': constituents,
                                     'trajectories': trajectories, 'structures': structures}), f)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:43<00:00,  9.70it/s]


#### View structures | trajectories

In [12]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results 
with open('results/results_0.pkl', 'rb') as f: 
    results = pkl.load(f)
    constituents = results['constituents']
    trajectories = results['trajectories']
    structures = results['structures']

mol_id = 1
view_trajectory = False  #### True: animate the sampling trajectory (requires save_traj=True when sampling)

constituent = np.asarray(constituents[mol_id])
#### keep real atoms only (index 0 is padding); use a mask rather than a prefix slice, matching
#### the `c > 0` filter used in graph assembly, so padding anywhere in the row is handled
atom_mask = constituent > 0
n_atoms = int(np.sum(atom_mask))
constituent = constituent[atom_mask]
structure = np.asarray(structures)[mol_id][atom_mask]
structure = structure - np.mean(structure, axis=0, keepdims=True)
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
print("This is a molecule with {} atoms, rg = {:.2f} ang".format(n_atoms, rg))
print("WARNING: bonds provided by NGLViewer may be problematic")

#### trajectories is an empty list unless Langevin_inference ran with save_traj=True;
#### indexing that as a 4-D (step, molecule, atom, xyz) array would raise IndexError
trajectory = None
trajectory_array = np.asarray(trajectories)
if trajectory_array.ndim == 4:
    trajectory = trajectory_array[:, mol_id][:, atom_mask]
    trajectory = trajectory - np.mean(trajectory, axis=1, keepdims=True)

mol = mda.Universe.empty(n_atoms=n_atoms)
mol.add_TopologyAttr('names', [constituent_types[i] for i in constituent])
if view_trajectory and trajectory is None:
    print("No saved trajectory found; showing the final structure instead")
if view_trajectory and trajectory is not None:
    mol.load_new(trajectory) ### view trajectories
else:
    mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view



This is a molecule with 9 atoms, rg = 1.92 ang


NGLWidget()

#### Graph assembly

In [13]:
import Xponge
from graph_assembler.graph_assembler import assemble_mol_graph

element_table = {
    'C': 6, 'N': 7, 'O': 8, 'F': 9
}

#### output directory for the per-molecule .mol2 files (created on a fresh checkout)
os.makedirs('results/mol2', exist_ok=True)

success_or_not = []
smileses = []
for i, (constituent, structure) in tqdm(enumerate(zip(constituents, structures)), total=len(constituents)):
    #### constituent names encode the element symbol as the first character and the H count as the last
    constituent_str = [constituent_types[c] for c in constituent if c > 0]
    atomic_numbers = [element_table[x[0]] for x in constituent_str]
    hydrogen_numbers = [int(x[-1]) for x in constituent_str]
    
    success, Xponge_mol, smiles = assemble_mol_graph(atomic_numbers, hydrogen_numbers, structure)
    success_or_not.append(success) 
    smileses.append(smiles if success else "")
    
    #### export to mol2
    if success:
        ##### delete Hs (Hs are added to help recognize the topology; their coordinates are fake)
        ##### NOTE(review): relies on `'H' in atom` identifying hydrogen atoms in Xponge — confirm
        atoms = Xponge_mol.atoms[::1]
        ##### highest index first, presumably so each deletion leaves the remaining indices valid
        hydrogen_atom_idx = sorted((idx for idx, atom in enumerate(atoms) if 'H' in atom), reverse=True)
        for atom_idx in hydrogen_atom_idx: 
            Xponge_mol.delete_atom(atom_idx)
        Xponge_mol.save_as_mol2('results/mol2/{}.mol2'.format(i), atomtype=None)
        
with open('results/result.smi', 'w') as f:
    for smiles in smileses:
        f.write("{}\n".format(smiles))

#### denominators use the number of molecules actually processed (guarding against an empty run);
#### failed molecules are stored as "", which must not count as a unique valid SMILES
n_total = max(len(smileses), 1)
unique_valid_smiles = {s for s in smileses if s}
print(".mol2 files are saved in results/mol2")
print("smiles are saved in results/result.smi")
print("valid: {:.2f}, unique and valid: {:.2f}".format(np.sum(success_or_not) / n_total, 
                                                       len(unique_valid_smiles) / n_total))

1024it [00:08, 122.21it/s]

.mol2 files are saved in results/mol2
smiles are saved in results/result.smi
valid: 0.99, unique and valid: 0.98



