In [None]:
import os
# Intended to stop JAX/XLA from preallocating most of the GPU memory at start-up
# (see JAX's GPU memory allocation docs); it only takes effect if set before `jax`
# is first imported, which happens in the next cell.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys 
# Make the repository's local packages (train/, cybertron/, inference/, ...) importable.
# Assumes the notebook runs from a direct sub-directory of the repo root -- TODO confirm.
sys.path.append("../")
import math
import pickle as pkl 
import numpy as np
from tqdm import tqdm

# NOTE: this prints the module search path (sys.path), not the interpreter's location.
print("We are now running Python in: ", sys.path)

## Preparation (nets, constants, params and functional utils)

In [None]:
import jax 
import jax.numpy as jnp

from functools import partial
from cybertron.common.config_load import load_config

#### load nets
from train.train import MolEditScoreNet
from cybertron.model.molct_plus import MolCT_Plus
from cybertron.readout import GFNReadout

from train.utils import set_dropout_rate_config
from jax.sharding import PositionalSharding

def _sharding(input, shards):

    n_device = shards.shape[0]
    if isinstance(input, (np.ndarray, jax.Array)):
        _shape = [n_device, ] + [1 for _ in range(input.ndim - 1)]
        return jax.device_put(input, shards.reshape(_shape))
    elif input is None:
        return jax.device_put(input, shards)
    else:
        raise TypeError(f"Invalid input: {input}")

from inference.inference import DPM_3_inference, Langevin_inference, DPM_pp_2S_inference

In [None]:
NATOMS = 64  # molecules are padded to this many atoms (later sliced back with [:n_atoms])
SHARDING = True  #### you can use multiple devices

# With sharding on, every visible JAX device gets a slice of each batch; otherwise one device.
if SHARDING:
    NDEVICES = len(jax.devices())
    print("{} DEVICES detected: {}".format(NDEVICES, jax.devices()))
else:
    NDEVICES = 1

def split_rngs(rng_key, shape):
    """Split ``rng_key`` into a grid of fresh keys plus one carry-over key.

    Args:
        rng_key: a JAX PRNG key.
        shape: an int or a tuple of ints giving the shape of the key grid.

    Returns:
        ``(keys, next_key)``: ``keys`` is reshaped to ``shape + (-1,)`` (one key per
        grid cell), and ``next_key`` is an unused key for later splitting.
    """
    if isinstance(shape, (int, np.integer)):
        shape = (shape,)
    shape = tuple(int(s) for s in shape)
    # math.prod returns a Python int (1 for shape == ()), as jax.random.split requires;
    # np.prod(()) would return the float 1.0 and be rejected.
    size = math.prod(shape)
    rng_keys = jax.random.split(rng_key, size + 1)
    return rng_keys[:-1].reshape(shape + (-1,)), rng_keys[-1]

rng_key = jax.random.PRNGKey(8888) #### set your random seed here
np.random.seed(7777)

In [None]:
##### initialize models (structure diffusion model)
encoder_config = load_config("../config/molct_plus.yaml")
gfn_config = load_config("../config/gfn.yaml")
gfn_config.settings.n_interactions = 4

# Module specs: flax module class + constructor args. The loop below replaces "module" with an
# instance and adds "callable_fn" (one parameter-bound `apply` per checkpoint).
modules = {
    "encoder": {"module": MolCT_Plus, 
                "args": {"config": encoder_config}},
    "gfn": {"module": GFNReadout, 
            "args": {"config": gfn_config}}
}

##### load params
# One score-network checkpoint per noise band. score_forward_fn (defined further down) selects
# checkpoint i for the i-th sigma band delimited by noise_thresholds, so there must be exactly
# len(load_ckpt_paths) - 1 thresholds, in increasing order.
load_ckpt_paths = ['../params/ZINC_3m/structure_model/moledit_params_track1.pkl', 
                   '../params/ZINC_3m/structure_model/moledit_params_track2.pkl',
                   '../params/ZINC_3m/structure_model/moledit_params_track3.pkl'] 
noise_thresholds = [0.35, 1.95]

# NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoint files.
params = []
for path in load_ckpt_paths:
    with open(path, 'rb') as f: 
        params.append(pkl.load(f))
    
if SHARDING:
    ##### replicate params (every device holds a full copy)
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())

for k, v in modules.items():
    # Inference only: set the config's dropout rate to 0 before instantiating the module.
    modules[k]['args']['config'] = \
        set_dropout_rate_config(modules[k]['args']['config'], 0.0)
    modules[k]["module"] = v["module"](**v["args"])
    modules[k]["callable_fn"] = [] 
    for param in params:
        # `.pop(k)` removes this module's sub-tree from the loaded params (mutating them);
        # re-running the cell is still safe because params are reloaded from disk above.
        partial_params = {"params": param["params"]['score_net'].pop(k)}
        modules[k]["callable_fn"].append(partial(modules[k]["module"].apply, partial_params))

# One MolEditScoreNet per checkpoint, wired to that checkpoint's encoder and GFN readout.
moledit_scorenets = [MolEditScoreNet(
        encoder=modules['encoder']['callable_fn'][k],
        gfn=modules['gfn']['callable_fn'][k],
    ) for k in range(len(load_ckpt_paths))]

In [None]:
##### initialize models (constituents model)
from config.transformer_config import transformer_config
from transformer.model import Transformer, TransformerConfig

# Vocabulary of constituent tokens: strings "<atomic number>_<H count>_<hybridization>"
# (they are built and counted in the Constituents Sampling cell).
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('../params/ZINC_3m/constituents_model/constituents_vocab.pkl', 'rb') as f:
    constituent_vocab_list = pkl.load(f)
NCONSTITUENTS = len(constituent_vocab_list) # 38
NRG_TOKENS = 3 # seq_len = 38 + 3
SEQ_LEN = NCONSTITUENTS + NRG_TOKENS

# Radius-of-gyration vocabulary: the last NRG_VOCABS ids of the model vocabulary. The final id
# (NATOMS + NRG_VOCABS) is used as the rg "unk" input token; the other NRG_VOCABS - 1 ids are
# the sampleable rg tokens.
NRG_VOCABS = 11
transformer_config.deterministic = True  # note: mutates the imported config object in place
transformer_config.dtype = jnp.float32
transformer = Transformer(
    config=TransformerConfig(
            **{
                **transformer_config,
                "vocab_size": NATOMS + NRG_VOCABS + 1, 
                "output_vocab_size": NATOMS + NRG_VOCABS + 1}, )
)

##### load params
# `params` is rebound here; the structure-model weights were already bound into
# moledit_scorenets by the previous cell.
with open("../params/ZINC_3m/constituents_model/moledit_params.pkl", "rb") as f:
    params = pkl.load(f)
    params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)

if SHARDING:
    ##### replicate params (every device holds a full copy)
    global_sharding = PositionalSharding(jax.devices()).reshape(NDEVICES, 1)
    params = jax.device_put(params, global_sharding.replicate())

def top_p_sampling(logits, rng_key, p=0.9):
    """Nucleus (top-p) sample one token index from a 1-D logits vector.

    Tokens are ranked by logit. The low-probability tail, whose cumulative
    probability stays below ``1 - p``, is pushed down by 1e5 before a
    categorical draw over the ranked logits.

    Returns:
        ``(token_index, next_rng_key)``, where ``next_rng_key`` is a fresh key
        for the caller's next draw.
    """
    order = jnp.argsort(logits)  # ascending: least likely token first
    ranked_logits = logits[order]
    tail_mass = jnp.cumsum(jax.nn.softmax(ranked_logits))
    drop = tail_mass < (1-p)
    penalty = drop.astype(jnp.float32) * (-1e5)

    next_key, draw_key = jax.random.split(rng_key)
    choice = jax.random.categorical(draw_key, ranked_logits + penalty)
    return order[choice], next_key
        
##### prepare functions, jit & vmap
# Constituents-transformer forward pass, called as (params, inputs, generation_result) -> logits.
jitted_logits_fn = jax.jit(transformer.apply)
# top_p_sampling handles one 1-D logits vector; the two vmaps lift it to (batch, seq_len, vocab)
# logits paired with a (batch, seq_len) grid of PRNG keys.
top_p_sampling_fn = jax.vmap(jax.vmap(jax.jit(partial(top_p_sampling, p=0.9))))

In [None]:
from rdkit import Chem

def SMILES_to_constituents(smi):
    """Parse a SMILES string into the heavy-atom "constituents" dict used by this notebook.

    Args:
        smi: SMILES string.

    Returns:
        dict with
          - "atomic_numbers": uint8 array, one entry per heavy atom,
          - "hydrogen_numbers": uint8 array of per-atom ``GetTotalNumHs()`` counts,
          - "hybridizations": uint8 array of RDKit ``GetHybridization()`` values,
          - "bonds": adjacency dict ``{atom_idx: {neighbor_idx: bond_type}}`` with
            ``bond_type = int(bond.GetBondType())`` (stored symmetrically).

    Raises:
        ValueError: if RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles returns None (instead of raising) for unparsable input; fail here
        # with a clear message rather than with an obscure error inside RemoveAllHs.
        raise ValueError("RDKit could not parse SMILES: {!r}".format(smi))
    mol = Chem.RemoveAllHs(mol)
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=np.uint8)
    hydrogen_numbers = np.array([atom.GetTotalNumHs() for atom in mol.GetAtoms()], dtype=np.uint8)
    hybridizations = np.array([atom.GetHybridization() for atom in mol.GetAtoms()], dtype=np.uint8)

    # uint8 atom indices cap molecules at 256 atoms; downstream inputs are padded to NATOMS (64).
    bond_ids = np.array([(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()], dtype=np.uint8)
    bond_types = np.array([int(bond.GetBondType()) for bond in mol.GetBonds()], dtype=np.uint8)

    # Symmetric adjacency dict; every heavy atom gets an entry, even without bonds.
    topology = {i: {} for i in range(len(atomic_numbers))}
    for (atom_i, atom_j), bond_type in zip(bond_ids, bond_types):
        topology[atom_i][atom_j] = topology[atom_j][atom_i] = bond_type

    constituents_dict = {
        "atomic_numbers": atomic_numbers,
        "hydrogen_numbers": hydrogen_numbers,
        "hybridizations": hybridizations,
        "bonds": topology,
    }

    return constituents_dict

def RDMol_to_constituents_and_structure(mol):
    """Extract the constituents dict and heavy-atom coordinates from an RDKit molecule.

    Hydrogens are removed first, so both the constituents and the coordinates refer to
    heavy atoms only.

    Args:
        mol: RDKit ``Mol`` with at least one conformer.

    Returns:
        ``(constituents_dict, structure)``. ``constituents_dict`` has the same keys as the
        output of ``SMILES_to_constituents``. ``structure`` is ``np.array`` of the last
        conformer's ``GetPositions()``.

    Raises:
        ValueError: if the molecule has no conformer.
    """
    mol = Chem.RemoveAllHs(mol)
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=np.uint8)
    hydrogen_numbers = np.array([atom.GetTotalNumHs() for atom in mol.GetAtoms()], dtype=np.uint8)
    hybridizations = np.array([atom.GetHybridization() for atom in mol.GetAtoms()], dtype=np.uint8)

    bond_ids = np.array([(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()], dtype=np.uint8)
    bond_types = np.array([int(bond.GetBondType()) for bond in mol.GetBonds()], dtype=np.uint8)

    # Symmetric adjacency dict; every heavy atom gets an entry, even without bonds.
    topology = {i: {} for i in range(len(atomic_numbers))}
    for (atom_i, atom_j), bond_type in zip(bond_ids, bond_types):
        topology[atom_i][atom_j] = topology[atom_j][atom_i] = bond_type

    constituents_dict = {
        "atomic_numbers": atomic_numbers,
        "hydrogen_numbers": hydrogen_numbers,
        "hybridizations": hybridizations,
        "bonds": topology,
    }

    conformers = list(mol.GetConformers())
    if not conformers:
        # Without this check `structure` would be unbound (UnboundLocalError).
        raise ValueError("Molecule has no conformer; cannot extract a structure")
    # When several conformers exist, the last one is used.
    structure = np.array(conformers[-1].GetPositions())

    return constituents_dict, structure

def mol_with_index(mol):
    """Tag every atom with its own index as the 'molAtomMapNumber' property.

    The molecule is modified in place and also returned, which is convenient for
    drawing atom indices.
    """
    for atom_index in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(atom_index)
        atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))
    return mol

## Structure Rendering & Sampling of Textual Representations (SMILES)

In [None]:
NSAMPLE_PER_DEVICE = 16 # molecules per device in one batch
NSTRUCTURE_PER_SAMPLE = 8 # 1024 is used in evaluation
NSAMPLES = 128 # 1024 is used in evaluation
NSAMPLE_PER_BATCH = int(NSAMPLE_PER_DEVICE * NDEVICES)
NSTRUCTURE_PER_BATCH = NSTRUCTURE_PER_SAMPLE * NSAMPLE_PER_BATCH
NBATCHES = NSAMPLES // NSAMPLE_PER_BATCH
NSTRUCTURES = NSAMPLES * NSTRUCTURE_PER_SAMPLE
INFERENCE_METHOD = "DPM_3"

print("NSAMPLE_PER_DEVICE: {}".format(NSAMPLE_PER_DEVICE))
print("NSAMPLES: {}".format(NSAMPLES))
print("NSAMPLE_PER_BATCH: {}".format(NSAMPLE_PER_BATCH))
print("NBATCHES: {}".format(NBATCHES))

# Every batch must be full. Later cells reshape the concatenated batch outputs to
# (NSAMPLES, NSTRUCTURE_PER_SAMPLE, ...), and that reshape fails if NSAMPLES is not a
# positive multiple of the per-batch molecule count.
if NBATCHES == 0 or NSAMPLES % NSAMPLE_PER_BATCH != 0:
    raise ValueError(
        "NSAMPLES ({}) must be a positive multiple of NSAMPLE_PER_DEVICE * NDEVICES ({}); "
        "adjust NSAMPLES or NSAMPLE_PER_DEVICE.".format(NSAMPLES, NSAMPLE_PER_BATCH))

#### jit and vmap functions
def score_forward_fn(atom_feat, bond_feat, x, atom_mask, sigma, rg, gamma=1.0):
    """Noise-band-routed, guidance-blended score for one input.

    Every network in ``moledit_scorenets`` is evaluated twice: once with the real
    ``bond_feat`` (conditional) and once with zeroed bond features (unconditional).
    The last output of each call is blended as ``gamma * cond + (1 - gamma) * uncond``.
    The result for the network whose sigma band contains ``sigma`` is kept. The bands
    are: below the first threshold, between consecutive thresholds, and at/above
    the last threshold.
    """
    n_bands = len(noise_thresholds) + 1
    band_masks = []
    for band in range(n_bands):
        if band == 0:
            band_masks.append(sigma < noise_thresholds[0])
        elif band == n_bands - 1:
            band_masks.append(sigma >= noise_thresholds[-1])
        else:
            band_masks.append(jnp.logical_and(sigma >= noise_thresholds[band - 1],
                                              sigma < noise_thresholds[band]))

    no_bonds = jnp.zeros_like(bond_feat)
    conditional = [net.apply({}, atom_feat, bond_feat, x, atom_mask, sigma, rg)[-1]
                   for net in moledit_scorenets]
    unconditional = [net.apply({}, atom_feat, no_bonds, x, atom_mask, sigma, rg)[-1]
                     for net in moledit_scorenets]
    blended = (gamma * jnp.array(conditional, jnp.float32)
               + (1.0 - gamma) * jnp.array(unconditional, jnp.float32))

    # One-hot (per band) weights broadcast over the trailing two axes of each score.
    band_weights = jnp.array(band_masks, dtype=jnp.float32)[..., None, None]
    return jnp.sum(band_weights * blended, axis=0)

score_forward_fn_jvj = jax.jit(jax.vmap(jax.jit(score_forward_fn)))
# Bind the chosen solver with the batched score function and its number of steps.
if INFERENCE_METHOD == "DPM_3":
    inference_fn = partial(DPM_3_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "DPM_pp_2S":
    inference_fn = partial(DPM_pp_2S_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=20, shard_inputs=SHARDING)
elif INFERENCE_METHOD == "Langevin":
    inference_fn = partial(Langevin_inference, score_fn=score_forward_fn_jvj, 
                           n_steps=1000, shard_inputs=SHARDING)
else:
    # Without this, a typo would leave `inference_fn` undefined, or silently reuse a stale
    # solver from an earlier run of this cell.
    raise ValueError("Unknown INFERENCE_METHOD {!r}; expected 'DPM_3', 'DPM_pp_2S' or 'Langevin'".format(
        INFERENCE_METHOD))

### Test SMILES

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only load this trusted, local dataset file.
with open('../moledit_dataset/smileses/geodiff_test_smileses.pkl', 'rb') as f:
    test_smileses = [str(x).strip() for x in pkl.load(f)[:NSAMPLES]]
# Later cells reshape results to (NSAMPLES, ...), so a dataset that is too short must fail here.
if len(test_smileses) != NSAMPLES:
    raise ValueError("Expected {} test SMILES but the dataset only provides {}; lower NSAMPLES.".format(
        NSAMPLES, len(test_smileses)))

### Constituents Sampling

In [None]:
# constituents_dicts: SMILES -> constituents dict. NOTE(review): keyed by SMILES, so duplicate
# SMILES in the test set would share (and overwrite) one entry.
# constituents_arrs: per molecule, the count of each vocabulary token.
constituents_dicts = {}
constituents_arrs = []

### preprocess smiles
for i, smi in tqdm(enumerate(test_smileses)):
    constituents_dict = SMILES_to_constituents(smi)
    constituents_dicts[smi] = constituents_dict

    #### count constituent tokens "<Z>_<nH>_<hybridization>" (rg is sampled in a later cell) 
    constituents_str = np.array(["{}_{}_{}".format(i,j,k) for i,j,k in zip(constituents_dict['atomic_numbers'], 
                                                                           constituents_dict['hydrogen_numbers'], 
                                                                           constituents_dict['hybridizations'])])
    constituents_arrs.append(np.array([np.sum(constituents_str == v) for v in constituent_vocab_list]))
# shape (n_molecules, NCONSTITUENTS)
constituents_arrs = jnp.array(constituents_arrs, dtype=jnp.int32)

### sample rgs
# Transformer inputs, one row per structure to generate (NSTRUCTURES rows). Rows are sample-major:
# each molecule's counts are repeated NSTRUCTURE_PER_SAMPLE times.
# - "generation_result": constituent counts with a +1 offset in the first NCONSTITUENTS
#   positions; the NRG_TOKENS rg positions are filled by the sampling loop below.
# - "inputs": rg positions set to the rg "unk" token (NATOMS + NRG_VOCABS, the last vocab id).
input_dict = {
    "inputs": jnp.ones((NSTRUCTURES, SEQ_LEN), dtype=jnp.int32), 
    "generation_result": jnp.ones((NSTRUCTURES, SEQ_LEN), dtype=jnp.int32)
}
input_dict["generation_result"] = input_dict["generation_result"].at[:, :NCONSTITUENTS].set(
                                            jnp.repeat(constituents_arrs, NSTRUCTURE_PER_SAMPLE, axis=0) + 1)
input_dict["inputs"] = input_dict["inputs"].at[:, NCONSTITUENTS:].set(NATOMS + NRG_VOCABS) #### unk token for rg

inv_temperature = 1.25  # logits are multiplied by this before sampling (>1 sharpens)
generation_results = []
rng_keys, rng_key = split_rngs(rng_key, (NBATCHES, ))
if SHARDING:
    #### shard inputs: split the leading (structure) axis of every array across devices
    ds_sharding = partial(_sharding, shards=global_sharding)
for b in tqdm(range(NBATCHES)):
    # jax.tree_util.tree_map replaces jax.tree_map, which is deprecated/removed in recent JAX.
    input_dict_ = jax.tree_util.tree_map(lambda x:x[b*NSTRUCTURE_PER_BATCH:(b+1)*NSTRUCTURE_PER_BATCH], input_dict)
    rng_key_b = rng_keys[b]
    # One PRNG key per (structure, position); the leftover key is unused.
    rng_keys_b, _ = split_rngs(rng_key_b, (NSTRUCTURE_PER_BATCH, SEQ_LEN))

    if SHARDING:
        input_dict_ = jax.tree_util.tree_map(ds_sharding, input_dict_)
        rng_keys_b = ds_sharding(rng_keys_b)

    # Autoregressively fill the NRG_TOKENS rg positions, one position per step.
    for step in range(NCONSTITUENTS, SEQ_LEN):
        logits = jitted_logits_fn(params,
                                  input_dict_['inputs'],
                                  input_dict_['generation_result'])
        # Only the NRG_VOCABS - 1 sampleable rg ids stay eligible; count tokens and the final
        # vocab id (the rg "unk" input token) are pushed down by 1e5.
        valid_logits_mask = jnp.zeros_like(logits, dtype=jnp.float32).at[..., -NRG_VOCABS:-1].set(1)
        logits += (-1e5) * (1.0 - valid_logits_mask)
        # A token is drawn for every position, but only the current `step` is written back.
        sampled_token, rng_keys_b = top_p_sampling_fn(logits * inv_temperature, rng_keys_b)
        input_dict_['generation_result'] = \
            input_dict_['generation_result'].at[..., step].set(sampled_token[..., step])
    generation_results.append(input_dict_['generation_result'])
    
# Token ids carry a +1 offset (counts were written as count + 1). After subtracting 1, rg ids
# become NATOMS + digit, with digit in [0, NRG_VOCABS - 2].
generation_result = np.array(jnp.concatenate(generation_results, axis=0)) - 1
# Rows are sample-major, so this reshape groups each molecule's NSTRUCTURE_PER_SAMPLE sequences.
for i, (smi, seqs) in tqdm(enumerate(zip(test_smileses, generation_result.reshape(NSAMPLES,NSTRUCTURE_PER_SAMPLE,-1)))):    
    #### decode rg from the last NRG_TOKENS digits (t0, t1, t2, ...):
    #### rg = exp(t0) * float("t1.t2...")
    rg_seqs = seqs[:, -NRG_TOKENS:] - NATOMS
    rgs = []
    for rg_seq in rg_seqs:
        rg = np.exp(rg_seq[0]) * float("{}.{}".format(rg_seq[1], "".join([str(x) for x in rg_seq[2:]])))
        rgs.append(rg)
    # One sampled rg per structure to be generated for this molecule.
    constituents_dicts[smi].update({"radius_of_gyrations": np.array(rgs, dtype=np.float32)})

### Structure Sampling

In [None]:
# Show the constituents (plus sampled rgs) of the first test molecule.
example_constituents = constituents_dicts[test_smileses[0]]
print("Example constituents: ")
print("\tatomic numbers: ", example_constituents['atomic_numbers'])
print("\thydrogen numbers: ", example_constituents['hydrogen_numbers'])
print("\thybridizations: ", example_constituents['hybridizations'])
print("\tradius of gyrations: ", example_constituents['radius_of_gyrations'])
print("\tbonds: ", example_constituents['bonds'])
print("\t**REMARK**: hybridization codes are RDKit HybridizationType values")

from inference.utils import preprocess_data

print("Preprocessing inputs")
# NSTRUCTURE_PER_SAMPLE model inputs per molecule: identical constituents/bonds, each paired with
# one of the rg values sampled above (passed as a 1-element list). NATOMS is passed to
# preprocess_data as its size argument (presumably the padding length -- confirm).
input_dicts = []
for smi in tqdm(test_smileses):
    d = constituents_dicts[smi]
    input_dicts.extend([preprocess_data({**{k: v for k, v in d.items() if k != 'radius_of_gyrations'}, 
                                         "radius_of_gyrations": [d['radius_of_gyrations'][i]]}, NATOMS) for i in range(NSTRUCTURE_PER_SAMPLE)])
# Stack per-structure inputs into batched arrays (rows stay sample-major); this rebinds `input_dict`.
input_dict = {
    k: np.stack([d[k] for d in input_dicts]) for k in input_dicts[0].keys()
}

print("input shape & dtypes: ")
for k, v in input_dict.items():
    print("\t{} shape: {} dtype: {}".format(k, v.shape, v.dtype))

#### With DPM Solver

In [None]:
rng_keys, rng_key = split_rngs(rng_key, (NBATCHES,))
# jax.tree_util.tree_map replaces jax.tree_map, which is deprecated/removed in recent JAX.
input_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), input_dict)
structures, trajectories = [], []

#### JAX compiles a jitted function when you call it first time.
### so it will be slow when you run this block first time.
print("inference requires {} batches running on {} GPUs".format(NBATCHES, NDEVICES))
for b in range(NBATCHES):
    # Rows are sample-major, so each batch holds NSAMPLE_PER_BATCH whole molecules.
    batch_slice = slice(b * NSTRUCTURE_PER_BATCH, (b + 1) * NSTRUCTURE_PER_BATCH)
    batch_inputs = jax.tree_util.tree_map(lambda x: x[batch_slice], input_dict)
    structures_, trajectories_, _ = inference_fn(batch_inputs, rng_keys[b])
    structures.append(structures_)
    # trajectories.append(trajectories_) # if you want to save trajectories

In [None]:
structures = np.array(structures).reshape(NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3)
# Trajectories exist only if the (commented-out) append in the inference loop is enabled.
trajectories = np.array(trajectories).transpose((1, 0, 2, 3, 4)).reshape(-1, NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3) if len(trajectories) > 0 \
                else np.array(trajectories)

#### save results
result_path = '../results/structure_rendering/result_geodiff_test_smiles.pkl'
# Create the output directory on first use so open(..., 'wb') does not fail.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, 'wb') as f:
    # Every leaf (SMILES strings and bond-dict values included) is converted with np.array.
    # jax.tree_util.tree_map replaces jax.tree_map, which is deprecated/removed in recent JAX.
    pkl.dump(jax.tree_util.tree_map(np.array,
                          {'smiles': test_smileses,
                           'constituents': [{k: constituents_dicts[smi][k] for k in constituents_dicts[smi].keys() if k != 'bonds'}
                                            for smi in test_smileses],
                           'conditional_infos': [{'bonds': constituents_dicts[smi]['bonds']} for smi in test_smileses],
                           'trajectories': trajectories,
                           'structures': structures}), f)

#### With FPS Solver

In [None]:
import numpy as np
import Xponge
from Xponge.helper import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Lipinski import RotatableBondSmarts
from rdkit.Geometry import Point3D
from rdkit.Chem import rdMolTransforms

def get_rotable_dihedrals(mol, rotatable_pattern=None):
    """Enumerate dihedral quadruplets (i, j, k, l) around every rotatable bond j-k.

    For each atom pair (j, k) matched by ``rotatable_pattern``, every neighbor i of j
    (other than k) is combined with every neighbor l of k (other than j). Neighbors
    are visited in bond order.

    Args:
        mol: RDKit ``Mol``.
        rotatable_pattern: query passed to ``mol.GetSubstructMatches``; defaults to
            ``rdkit.Chem.Lipinski.RotatableBondSmarts``.

    Returns:
        int32 array of shape (n_dihedrals, 4); shape (0, 4) when there are none.
    """
    if rotatable_pattern is None:
        rotatable_pattern = RotatableBondSmarts

    # Neighbor lists in bond order, kept as plain Python ints (uint8 would overflow past 255 atoms).
    neighbors = {atom_idx: [] for atom_idx in range(mol.GetNumAtoms())}
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        neighbors[begin].append(end)
        neighbors[end].append(begin)

    dihedral_ids = []
    for atom_j, atom_k in mol.GetSubstructMatches(rotatable_pattern):
        for atom_i in neighbors[atom_j]:
            if atom_i == atom_k:
                continue
            for atom_l in neighbors[atom_k]:
                if atom_l == atom_j:
                    continue
                dihedral_ids.append([atom_i, atom_j, atom_k, atom_l])
    # reshape keeps the (n, 4) layout even when the list is empty.
    return np.array(dihedral_ids, dtype=np.int32).reshape(-1, 4)

# Rotatable-bond dihedral atom ids for the Fokker-Planck trick, computed from the FIRST test
# molecule only. NOTE(review): the same ids are passed for every batch in the next cell, which is
# only meaningful if all sampled molecules share this topology (e.g. NSAMPLES == 1) -- confirm.
# NOTE: this relies on the local get_rotable_dihedrals(mol) defined just above. The graph-assembly
# cell later imports a different get_rotable_dihedrals that shadows it, so re-running this cell
# after that one calls the wrong function.
mol = Chem.MolFromSmiles(test_smileses[0])
rotable_bonds_dihedral_ids = get_rotable_dihedrals(mol)
rotable_bonds_dihedral_ids = jnp.array(rotable_bonds_dihedral_ids)

In [None]:
fp_coeff = 2e3  ### This coefficient controls the additional Fokker-Planck-gradient term in the SDE
rng_keys, rng_key = split_rngs(rng_key, (NBATCHES,))
# jax.tree_util.tree_map replaces jax.tree_map, which is deprecated/removed in recent JAX.
input_dict = jax.tree_util.tree_map(lambda x:jnp.array(x), input_dict)
structures, trajectories = [], []

#### JAX compiles a jitted function when you call it first time.
### so it will be slow when you run this block first time.
print("inference requires {} batches running on {} GPUs".format(NBATCHES, NDEVICES))
for b in range(NBATCHES):
    # Rows are sample-major, so each batch holds NSAMPLE_PER_BATCH whole molecules.
    batch_slice = slice(b * NSTRUCTURE_PER_BATCH, (b + 1) * NSTRUCTURE_PER_BATCH)
    batch_inputs = jax.tree_util.tree_map(lambda x: x[batch_slice], input_dict)
    structures_, trajectories_, _ = inference_fn(
        batch_inputs, rng_keys[b],
        dihedral_dict={'dihedral_atom_ids': rotable_bonds_dihedral_ids, 'fp_coeff': fp_coeff})
    structures.append(structures_)
    # trajectories.append(trajectories_) # if you want to save trajectories

In [None]:
structures = np.array(structures).reshape(NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3)
# Trajectories exist only if the (commented-out) append in the inference loop is enabled.
trajectories = np.array(trajectories).transpose((1, 0, 2, 3, 4)).reshape(-1, NSAMPLES, NSTRUCTURE_PER_SAMPLE, NATOMS, 3) if len(trajectories) > 0 \
                else np.array(trajectories)

#### save results
result_path = '../results/structure_rendering/result_geodiff_test_smiles_fps.pkl'
# Create the output directory on first use so open(..., 'wb') does not fail.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, 'wb') as f:
    # Every leaf (SMILES strings and bond-dict values included) is converted with np.array.
    # jax.tree_util.tree_map replaces jax.tree_map, which is deprecated/removed in recent JAX.
    pkl.dump(jax.tree_util.tree_map(np.array,
                          {'smiles': test_smileses,
                           'constituents': [{k: constituents_dicts[smi][k] for k in constituents_dicts[smi].keys() if k != 'bonds'}
                                            for smi in test_smileses],
                           'conditional_infos': [{'bonds': constituents_dicts[smi]['bonds']} for smi in test_smileses],
                           'trajectories': trajectories,
                           'structures': structures}), f)

### View Structures

In [None]:
#### view trajectories | structures
import MDAnalysis as mda 
import nglview as nv 

#### load your results (switch to the commented path to inspect the FPS run) 
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(f'../results/structure_rendering/result_geodiff_test_smiles.pkl', 'rb') as f:
# with open(f'../results/structure_rendering/result_geodiff_test_smiles_fps.pkl', 'rb') as f:
    results = pkl.load(f)
    test_smileses = results['smiles']
    constituents = results['constituents']
    conditional_infos = results['conditional_infos']
    # trajectories = results['trajectories']
    structures = results['structures']

# Element symbols for the heavy atoms handled here (constituents are built after RemoveAllHs).
# Any other atomic number raises KeyError when the atom names are built below.
elements = {
    6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 
    35: 'Br', 53: 'I'
}

# Which molecule, and which of its generated structures, to display.
mol_id, structure_id = 0, 0
print("smiles: ", test_smileses[mol_id], len(test_smileses))
atomic_numbers = constituents[mol_id]['atomic_numbers']
hydrogen_numbers = constituents[mol_id]['hydrogen_numbers']
hybridizations = constituents[mol_id]['hybridizations']
print("atomic numbers:", ",".join([str(x) for x in atomic_numbers]))
print("hydrogen_numbers:", ",".join([str(x) for x in hydrogen_numbers]))
print("hybridizations:", ",".join([str(x) for x in hybridizations]))
n_atoms = len(atomic_numbers)
# trajectory = np.array(trajectories)[:, mol_id, structure_id, :n_atoms, :]
# Drop the padding rows: structures are stored padded to NATOMS atoms.
structure = np.array(structures)[mol_id, structure_id, :n_atoms, :]
# Radius of gyration measured about the coordinate origin, not the centroid.
# NOTE(review): this equals the usual rg only if generated structures are zero-centred -- confirm.
rg = np.sqrt(np.sum(structure ** 2) / n_atoms)
# Printed as "<rg of the generated structure>/<sampled target rg>".
print("This is a molecule with {} atoms, rg = {:.2f}/{:.2f} ang".format(n_atoms, rg, constituents[mol_id]['radius_of_gyrations'][structure_id]))
# NOTE(review): conditional_infos[mol_id] is never printed after this message.
print("Conditional infos indicate the following bonds")
print("WARNING: bonds provided by NGLViewer may be problematic")

# MDAnalysis universe carrying only atom names; no bonds are supplied to the viewer.
mol = mda.Universe.empty(n_atoms=len(atomic_numbers))
mol.add_TopologyAttr('names', ["{}{}".format(elements[n], i) for i, n in enumerate(atomic_numbers)])
# mol.add_TopologyAttr('names', ["{}".format(elements[n]) for i, n in enumerate(atomic_numbers)])
# mol.load_new(trajectory - np.mean(trajectory, axis=1, keepdims=True)) ### view trajectories 
mol.load_new(structure) ### view structures
view = nv.show_mdanalysis(mol)
view

### Graph Assembly

In [None]:
# Reload the saved structures (switch to the commented path to assemble the FPS run).
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(f'../results/structure_rendering/result_geodiff_test_smiles.pkl', 'rb') as f:
# with open(f'../results/structure_rendering/result_geodiff_test_smiles_fps.pkl', 'rb') as f:
    results_dict = pkl.load(f)

test_smileses = results_dict['smiles']
constituents = results_dict['constituents']
conditional_infos = results_dict['conditional_infos']
structures = results_dict['structures']

#### Evaluation on Molecular Physics Instability (MPI)

In [None]:
import Xponge
from Xponge.helper import rdkit as Xponge_rdkit_helper
from graph_assembler.graph_assembler import assemble_mol_graph, check_bonds, get_rotable_bonds, get_rotable_dihedrals, uff_eval
from functools import reduce
import warnings
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*') ### disable rdkit warning
from multiprocessing import Pool ### enabling multiprocess

def assemble_fn(arg_dict):
    """Assemble and score molecular graphs for every generated structure of one molecule.

    Runs inside a ``multiprocessing.Pool`` worker (see below), so it only uses its
    argument and module-level imports. ``get_rotable_dihedrals`` here resolves to the
    ``graph_assembler`` version imported in this cell.

    Args:
        arg_dict: dict with keys
            'mol_id': molecule index (carried for bookkeeping; not used here),
            'atomic_numbers', 'hydrogen_numbers': per-heavy-atom constituents,
            'structures': sequence of padded coordinate arrays; only the first
                ``len(atomic_numbers)`` rows of each are used,
            'bonds': reference bond dict passed to ``check_bonds``.

    Returns:
        dict with
            'valid_or_not': bool array, True where graph assembly succeeded,
            'smileses': assembled SMILES per structure ("" on failure),
            'num_required_bonds' / 'num_satisfied_bonds': int32 arrays filled from
                ``check_bonds``; -1 for structures whose graph could not be assembled,
            'rotable_bonds', 'uff_enes', 'uff_forces', 'rotable_bonds_dihedrals':
                one entry per structure whose bonds fully match the reference
                (UFF entries are None when UFF evaluation failed).
    """
    atomic_numbers = arg_dict['atomic_numbers']
    hydrogen_numbers = arg_dict['hydrogen_numbers']
    mol_structures = arg_dict['structures']
    conditional_info = arg_dict['bonds']
    n_structures = len(mol_structures)

    ret_dict = {
        'valid_or_not': np.zeros(n_structures, dtype=np.bool_),
        'smileses': [],
        'num_required_bonds': np.zeros(n_structures, dtype=np.int32),
        'num_satisfied_bonds': np.zeros(n_structures, dtype=np.int32),
        'rotable_bonds': [],
        'rotable_bonds_dihedrals': [],
        'uff_enes': [], 'uff_forces': [],
    }

    rdmols = []
    for structure_id, structure in enumerate(mol_structures):
        structure = structure[:len(atomic_numbers)]
        success, Xponge_mol, smiles = assemble_mol_graph(atomic_numbers, hydrogen_numbers, structure)
        ret_dict['valid_or_not'][structure_id] = success
        ret_dict['smileses'].append(smiles if success else "")

        if not success:
            # Mark only THIS structure as failed; results already recorded for the other
            # structures of the molecule must not be overwritten.
            ret_dict['num_required_bonds'][structure_id] = -1
            ret_dict['num_satisfied_bonds'][structure_id] = -1
            continue

        ##### delete Hs (Hs are added to help recognize the topology; their coordinates are fake).
        # NOTE(review): `'H' in atom` presumably tests the Xponge atom's name/type -- confirm.
        # Indices are deleted in descending order so the remaining ones stay valid.
        atoms = Xponge_mol.atoms[::1]
        hydrogen_atom_idx = np.sort([idx for idx, atom in enumerate(atoms) if 'H' in atom])[::-1]
        for atom_idx in hydrogen_atom_idx:
            Xponge_mol.delete_atom(atom_idx)

        ### check bonds: ret[0] / ret[1] are the required / satisfied bond counts
        ret = check_bonds(Xponge_mol.bonds, conditional_info, allow_perm=False)
        ret_dict['num_required_bonds'][structure_id] = ret[0]
        ret_dict['num_satisfied_bonds'][structure_id] = ret[1]
        if ret[0] != ret[1]:
            continue

        ### rotatable bonds of the RDKit view of the assembled molecule
        mol = Xponge_rdkit_helper.assign_to_rdmol(Xponge_mol)
        ret_dict['rotable_bonds'].append(get_rotable_bonds(mol))
        rdmols.append((mol, structure))

        ### uff evaluation (MPI); best-effort: a failure records None instead of aborting the molecule
        try:
            ene, force, _, _, _ = uff_eval(mol, structure)
        except Exception:
            ene, force = None, None
        ret_dict['uff_enes'].append(ene)
        ret_dict['uff_forces'].append(force)

    # Union of the rotatable bonds found over all topology-consistent structures (empty when
    # none matched, in which case `rdmols` is empty too). It is passed to every structure,
    # presumably so all of them report dihedrals for the same bond set.
    rotable_bonds = set().union(*(set(bonds) for bonds in ret_dict['rotable_bonds']))
    for (mol, structure) in rdmols:
        ret_dict['rotable_bonds_dihedrals'].append(
            get_rotable_dihedrals(mol, structure, given_bonds=conditional_info, given_rotable_bonds=rotable_bonds)
        )

    return ret_dict
    
# One work item per molecule for assemble_fn (picklable plain dicts for the process pool).
arg_dicts = []
for mol_id, (atomic_numbers, hydrogen_numbers, mol_structures, conditional_info) in \
    tqdm(enumerate(zip([c['atomic_numbers'] for c in constituents],
                       [c['hydrogen_numbers'] for c in constituents],
                       structures, 
                       [c['bonds'] for c in conditional_infos]))):
    arg_dicts.append(
        {'mol_id': mol_id, 'atomic_numbers': atomic_numbers, 'hydrogen_numbers': hydrogen_numbers, 
         'structures': mol_structures, 'bonds': conditional_info}
    )

ret_dicts = []
##### parallel code
# The context manager cleans up the worker processes even if a worker raises, and `total`
# follows the actual number of work items instead of assuming NSAMPLES.
with Pool(processes=os.cpu_count()) as pool:
    for d in tqdm(pool.imap(func=assemble_fn, iterable=arg_dicts),
                  total=len(arg_dicts)):
        ret_dicts.append(d)
##### serial code
# for arg_dict in tqdm(arg_dicts):
#     ret_dicts.append(assemble_fn(arg_dict))

# Per-structure arrays stacked to (n_molecules, NSTRUCTURE_PER_SAMPLE).
valid_or_not = np.stack([d['valid_or_not'] for d in ret_dicts])
num_required_bonds = np.stack([d['num_required_bonds'] for d in ret_dicts])
num_satisfied_bonds = np.stack([d['num_satisfied_bonds'] for d in ret_dicts])
# Ragged per-molecule lists: one entry per topology-consistent structure only.
rotable_bonds = [d['rotable_bonds'] for d in ret_dicts]
rotable_bonds_dihedrals = [d['rotable_bonds_dihedrals'] for d in ret_dicts]
uff_enes = [d['uff_enes'] for d in ret_dicts]
uff_forces = [d['uff_forces'] for d in ret_dicts]
smileses = [d['smileses'] for d in ret_dicts]

# Molecular Physics Instability (MPI): mean UFF force norm, averaged per molecule and then over
# molecules. Structures whose UFF evaluation failed (None) are skipped. Molecules with no evaluable
# structure are excluded, because np.linalg.norm of an empty list is 0.0 and would bias the mean.
per_molecule_mpi = []
for mol_forces in uff_forces:
    evaluable_forces = [f for f in mol_forces if f is not None]
    if evaluable_forces:
        per_molecule_mpi.append(np.mean(np.linalg.norm(evaluable_forces, axis=-1)))

print("Molecular Physics Instability: {:.2f} among {} structures, {} structures per molecule".format(
    np.mean(per_molecule_mpi) if per_molecule_mpi else float("nan"),
    NSTRUCTURES, NSTRUCTURE_PER_SAMPLE,
))

In [None]:
### save analysis results (use the commented path for the FPS run)
# NOTE: `rotable_bonds` and `smileses` from the previous cell are not saved.
with open(f'../results/structure_rendering/analysis_result_geodiff_smiles_rdkit.pkl', 'wb') as f:
# with open(f'../results/structure_rendering/analysis_result_geodiff_smiles_fp_trick.pkl', 'wb') as f:
    pkl.dump(
        {'n_atoms': np.array([len(c['atomic_numbers']) for c in constituents], dtype=np.int32),
         'valid_or_not': valid_or_not, 
         'num_required_bonds': num_required_bonds, 
         'num_satisfied_bonds': num_satisfied_bonds,
         'rotable_bonds_dihedrals': rotable_bonds_dihedrals, 
         'uff_enes': uff_enes, 'uff_forces': uff_forces, }, f
    )

#### Evaluation on Conformational Diversity (LogNeff)

In [None]:
# Reload the analysis so the Neff section can run without redoing graph assembly.
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(f'../results/structure_rendering/analysis_result_geodiff_smiles_rdkit.pkl', 'rb') as f:
    analysis_result = pkl.load(f)
    n_atoms = analysis_result['n_atoms']
    valid_or_not = analysis_result['valid_or_not']
    num_required_bonds = analysis_result['num_required_bonds']
    num_satisfied_bonds = analysis_result['num_satisfied_bonds']
    rotable_bonds_dihedrals = analysis_result['rotable_bonds_dihedrals']
    uff_enes = analysis_result['uff_enes']
    uff_forces = analysis_result['uff_forces']

In [None]:
### Neff calculation 
### Neff = Sum((\sum_t^N I[D(s, t) < delta])^-1)
from functools import reduce
from multiprocessing import Pool ### enabling multiprocess

delta = 0.3  # distance threshold below which two structures count as the same conformer
nstructures_per_sample_list = [8, ] # [8, 16, 32, 64, 128, 256, ..., 1024]
def neff_fn(rotable_bond_dihedral, threshold=None, b_sizes=None):
    """Effective number of distinct conformers (Neff) for one molecule.

    Neff = sum_s ( sum_t I[D(s, t) < threshold] )^-1, where D(s, t) is the maximum, over
    rotatable bonds, of the dihedral distance 2 - 2 cos(alpha_s - alpha_t).

    Args:
        rotable_bond_dihedral: one entry per structure. Each entry is a mapping
            {rotatable bond: dihedral angle} (presumably radians, since it goes through
            np.cos) or None. The bond set is taken from the first non-None entry.
        threshold: distance threshold; defaults to the module-level ``delta``.
        b_sizes: evaluate Neff over the first b structures for each b; defaults to the
            module-level ``nstructures_per_sample_list``.

    Returns:
        float array with one Neff per entry of ``b_sizes``, or None when there is nothing
        to compare (no non-None entries or no rotatable bonds).
    """
    threshold = delta if threshold is None else threshold
    b_sizes = nstructures_per_sample_list if b_sizes is None else b_sizes

    dihedral_dicts = [rd for rd in rotable_bond_dihedral if rd is not None]
    if not dihedral_dicts:
        return None
    rotable_bonds = list(dihedral_dicts[0].keys())
    if not rotable_bonds:
        return None

    per_bond_distances = []
    for bond in rotable_bonds:
        dihedrals = np.array([d[bond] for d in dihedral_dicts], dtype=np.float32)
        # Pairwise 2 - 2 cos(alpha - beta) for every structure pair, vectorised via broadcasting.
        per_bond_distances.append(2 - 2 * np.cos(dihedrals[:, None] - dihedrals[None, :]))
    dihedral_distances = np.array(per_bond_distances)

    neffs = []
    for b_size in b_sizes:
        # Two structures are "the same" only if every rotatable bond is within the threshold.
        # The diagonal is always 0, so each row count is at least 1.
        structure_distance = np.max(dihedral_distances[:, :b_size, :b_size], axis=0)
        neffs.append(np.sum(1.0 / np.sum(structure_distance < threshold, axis=-1)))

    return np.array(neffs)

# The context manager cleans up the worker processes even if a worker raises, and `total`
# follows the actual number of molecules instead of assuming NSAMPLES.
neff_arr = []
with Pool(processes=os.cpu_count()) as pool:
    for neff in tqdm(pool.imap(func=neff_fn, iterable=rotable_bonds_dihedrals),
                     total=len(rotable_bonds_dihedrals)):
        neff_arr.append(neff)

# Molecules without rotatable bonds (None) are excluded from the average. If none remain,
# np.mean([]) would give a 0-d NaN that cannot be zipped, so report that case explicitly.
valid_neffs = [x for x in neff_arr if x is not None]
if valid_neffs:
    mean_neff_arr = np.mean(valid_neffs, axis=0)
    for b_size, mean_neff in zip(nstructures_per_sample_list, mean_neff_arr):
        print("Mean sampling efficiency: Neff = {:.2f} in {} structures".format(mean_neff, b_size))
else:
    print("No molecule has rotatable-bond dihedrals; Neff is undefined.")