In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import CDPL.Chem as CDPLChem
from source.utils.conforge_conformer_generation import *
from source.utils.mol_utils import drop_disconnected_components, preprocess_mol
from rdkit import Chem as rdChem
from source.utils.npz_utils import get_field_from_npzs

from source.utils.conforge_conformer_generation import generate_conformers
from source.utils.mol2pyg import mols2pyg_list_with_targets
import torch
from source.utils.npz_utils import save_pyg_as_npz
import shutil
from pathlib import Path


def preprocess_smile(s):
    """Return the SMILES string obtained after cleaning up the SMILES ``s``.

    The input is passed through ``drop_disconnected_components``, parsed with
    RDKit, run through ``preprocess_mol`` and finally written back as SMILES.
    """
    cleaned_smiles = drop_disconnected_components(s)
    parsed_mol = rdChem.MolFromSmiles(cleaned_smiles)
    return rdChem.MolToSmiles(preprocess_mol(parsed_mol))

def smi_to_cdpl_mol(s):
    """Parse the SMILES ``s`` into a CDPL molecule after project preprocessing.

    The clean-up steps are exactly those of ``preprocess_smile`` (drop
    disconnected components, RDKit parse, ``preprocess_mol``, back to SMILES).
    That helper is reused here instead of repeating the pipeline, so the two
    entry points cannot drift apart.

    Args:
        s: input SMILES string.

    Returns:
        Whatever ``CDPLChem.parseSMILES`` returns for the preprocessed SMILES.
    """
    return CDPLChem.parseSMILES(preprocess_smile(s))

def getSettings(minRMSD, max_confs=100):
    """Build the ConfGen conformer generator used throughout this notebook.

    Args:
        minRMSD: value assigned to the generator's ``minRMSD`` setting
            (2.5 was the value noted for the freesolv test).
        max_confs: cap handed to ``setMaxNumOutputConformers``. Defaults to
            100, the value that was previously hard-coded here.

    Returns:
        The configured ``ConfGen.ConformerGenerator``. Despite the function
        name, this is the generator itself, not only its settings object.
    """
    conf_gen = ConfGen.ConformerGenerator()
    conf_gen.settings.setSamplingMode(1) # AUTO = 0; SYSTEMATIC = 1; STOCHASTIC = 2;
    # NOTE(review): the timeout unit is not visible here (presumably ms) - confirm.
    conf_gen.settings.timeout = 36000 #! this is core
    conf_gen.settings.minRMSD = minRMSD #! this is core
    print(f'Using minRMSD = {conf_gen.settings.minRMSD}')
    conf_gen.settings.energyWindow = 150000.0
    conf_gen.settings.setMaxNumOutputConformers(max_confs)
    return conf_gen

def eval_confs(npz_path, minRMSD=2.5):
    """Histogram of the number of conformers generated per molecule.

    Args:
        npz_path: location handed to ``get_field_from_npzs``; every returned
            record is indexed with ``'smiles'``.
        minRMSD: threshold forwarded to ``getSettings``. Defaults to 2.5 (the
            value noted for the freesolv test) so one-argument calls work.

    Returns:
        A 1-tuple ``(histogram,)`` where ``histogram[k]`` is the number of
        counted molecules that produced ``k`` conformers. The 1-tuple shape is
        kept for compatibility with existing callers.

    Side effects:
        Prints the number of records before augmentation and the total number
        of conformers after it.
    """
    records = get_field_from_npzs(npz_path) #! get only smiles?
    print("Num of mols pre-data agumentation: ", len(records))
    conf_gen = getSettings(minRMSD)
    n_confs = []
    for record in records:
        mol = smi_to_cdpl_mol(str(record['smiles']))
        status, num_confs = generateConformationEnsembles(mol, conf_gen)
        try:
            # Statuses that have an entry in status_to_str are skipped
            # (presumably these are the error codes - confirm).
            status_to_str[status]
        except LookupError:
            # Only a missing entry means "count this molecule". Any other
            # failure (e.g. status_to_str not being defined) now surfaces
            # instead of being silently treated as a countable result.
            n_confs.append(num_confs)

    tot = sum(n_confs)
    print("Num of mols post-data agumentation: ", tot)
    return (np.bincount(np.array(n_confs)),)


# tests

In [None]:
# Conformer-count histogram for the freesolv train_backup directory.
# minRMSD is passed explicitly; 2.5 is the value noted for the freesolv test.
# NOTE(review): confirm this is the threshold intended for this directory.
save_dir = Path('/storage_common/nobilm/pretrain_paper/guacamol/EXPERIMENTS/freesolv/train_backup')
eval_confs(save_dir, 2.5)

In [None]:
# minRMSD is passed explicitly; 2.5 is the threshold logged in this notebook
# together with the histogram [0, 479, 28, 4, 1, 1].
eval_confs('/storage_common/nobilm/pretrain_paper/guacamol/EXPERIMENTS/freesolv/train_agumented', 2.5)
# array([  0, 505,   4]) # 36 minutes
# array([  0, 479,  28,   4,   1,   1]) # 15.3 seconds! Consistent result


In [None]:
# do the same with a subset of qm9:
# NOTE(review): the minRMSD behind the recorded output is not noted in this
# notebook; 2.5 (the freesolv value) is assumed here - confirm.
eval_confs('/storage_common/nobilm/pretrain_paper/guacamol/EXPERIMENTS/qm9ftTEST/train_data_dbg', 2.5)
# array([   0, 4988,   10])

In [None]:
# small subset of guacamol
# NOTE(review): the minRMSD behind the recorded output is not noted in this
# notebook; 2.5 (the freesolv value) is assumed here - confirm.
eval_confs('/storage_common/nobilm/pretrain_paper/guacamol/EXPERIMENTS/100k_random_split_first_test/dbg_train', 2.5)

# array([ 0, 83, 70, 23, 19,  8, 14,  6,  4,  5,  6,  3,  2,  0,  4,  0,  2,
#         1,  1,  0,  1,  1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
#         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1])

# SAVE AUGMENTED CONFORMERS

In [None]:
# Input directory whose .npz records are augmented with extra conformers below.
starting_path = Path('/storage_common/nobilm/pretrain_paper/guacamol/EXPERIMENTS/bace/val')
# eval_confs(starting_path, 2.0) # from 3354 to 11.2k

# Using minRMSD = 2.5
# (array([  0, 479,  28,   4,   1,   1]), 556)

# Num of mols pre-data agumentation:  513
# Using minRMSD = 2.0
# Num of mols post-data agumentation:  614

In [None]:
def getSettingsLOCAL(minRMSD):
    """Variant of ``getSettings`` that keeps at most 10 conformers per molecule.

    Everything except the output cap comes from ``getSettings`` (sampling
    mode, timeout, minRMSD, energy window, and the "Using minRMSD" printout),
    so the two configurations cannot silently diverge.

    Args:
        minRMSD: value forwarded to ``getSettings``.

    Returns:
        The configured conformer generator returned by ``getSettings``, with
        its maximum number of output conformers overridden to 10.
    """
    conf_gen = getSettings(minRMSD)
    # -n in https://cdpkit.org/v1.1.1/cdpl_python_cookbook/confgen/gen_ensemble.html
    conf_gen.settings.setMaxNumOutputConformers(10)
    return conf_gen

# Output directory for the augmented dataset. It is created up front because
# the folder must exist before anything is written; exist_ok=True keeps this
# cell re-runnable instead of failing when the directory is already there.
save_dir = Path('/storage_common/nobilm/pretrain_paper/guacamol/EXPERIMENTS/bace/val_agumented_10confs_1.5RMSD')
save_dir.mkdir(parents=True, exist_ok=True)

out = get_field_from_npzs(starting_path)
n_confs = []  # number of conformers obtained for each input record
settings = getSettingsLOCAL(minRMSD = 1.5)
for npz_id, npz in enumerate(out):

    s = str(npz['smiles'])
    y = float(npz['graph_labels']) #!

    s = preprocess_smile(s)

    conformers = generate_conformers(s, settings)
    n_confs.append(len(conformers))

    if not conformers:
        # No conformer was produced: carry the original file over unchanged.
        shutil.copy(npz.zip.filename, save_dir / Path(npz.zip.filename).name)
        continue

    # The non-positional fields are built once, from the first conformer only.
    pyg_mol_fixed_fields = mols2pyg_list_with_targets([conformers[0]], [s], [y])[0]

    #! here eventually write how many confs to keep for given mol
    # One float32 coordinate tensor per conformer, stacked along a new leading
    # dimension. GetPositions() yields all atom coordinates of the conformer at
    # once, replacing the per-atom GetAtomPosition loop.
    batched_pos = torch.stack([
        torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float32)
        for mol in conformers
    ])
    pyg_mol_fixed_fields.pos = batched_pos

    save_pyg_as_npz(pyg_mol_fixed_fields, f'{save_dir}/mol_{npz_id}')

# Total number of conformers over all records (records with none contribute 0).
tot = sum(n_confs)
np.bincount(n_confs),tot

In [None]:
# settings = getSettings(2.0)
# (array([  0, 449,  50,   6,   2,   2,   1,   1,   2]), 614)

# rmsd 1.5, max 10 confs
# (array([ 36,  37,  52,  22,  70,  34,  27,  36,  40,  37, 819]), 9914)

