In [7]:
from typing import List, Tuple, Dict, Optional, Sequence
import pickle
import os
from tqdm import tqdm

import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import KekulizeException, MolToSmiles



# Helper functions

In [18]:
def split_dataset_by_smiles(mols: Sequence[Chem.Mol])\
        -> Dict[str, List[Chem.Mol]]:
    """
    Groups molecules (conformers) by the SMILES string of their
    hydrogen-free form.

    Hydrogens are removed only to build the grouping key; the molecules
    stored in the result are the original input objects, kept in input
    order within each group.
    Args:
        mols: Sequence of Chem.Mol objects
    Returns:
        dict mapping each SMILES string (as produced by MolToSmiles) to the
        list of conformers that share it
    """

    grouped = {}
    for molecule in tqdm(mols):
        key = MolToSmiles(Chem.RemoveHs(molecule))
        grouped.setdefault(key, []).append(molecule)
    return grouped


def dict_to_list_of_lists(splited_dict: Dict[str, List[Chem.Mol]])\
        -> List[List[Chem.Mol]]:
    """
    Transforms a SMILES-keyed dict of conformers into a list of lists
    Args:
        splited_dict: Dictionary with smiles as keys and list
            of conformers for this smiles as values
    Returns:
        List of lists with conformers for each molecule, in the dict's
        iteration order (the SMILES keys are dropped)
    """

    return list(splited_dict.values())


def load_pkl(inp: str) -> Optional[Sequence[Chem.Mol]]:
    """
    Loads a pickled object (expected: a sequence of Chem.Mol) from a file.

    Security note: unpickling can execute arbitrary code, so only load
    files that come from a trusted source.
    Args:
        inp: path to a .pkl file
    Returns:
        the unpickled object
    """

    # The context manager closes the file even if unpickling raises.
    with open(inp, 'rb') as f:
        return pickle.load(f)


def load_sdf(inp: str) -> Optional[Sequence[Chem.Mol]]:
    """
    Reads all molecules from an SDF file (with sanitization).

    SDMolSupplier yields None for records it cannot parse/sanitize; those
    are dropped so downstream helpers (e.g. Chem.RemoveHs in
    split_dataset_by_smiles) never receive None.
    Args:
        inp: path to a .sdf file
    Returns:
        list of successfully read molecules, in file order
    """

    supplier = Chem.SDMolSupplier(inp, sanitize=True)
    return [mol for mol in supplier if mol is not None]


def load_data(inp: str) -> Optional[Sequence[Chem.Mol]]:
    """
    Loads molecules from a .pkl or .sdf file, dispatching on the file
    extension (case-insensitive).
    Args:
        inp: path to the input file
    Returns:
        loaded molecules, or None (after printing a message) if the
        extension is not supported
    """

    # splitext only looks at the final path component, so dots in
    # directory names or extension-less files are handled correctly.
    file_type = os.path.splitext(inp)[1][1:]
    normalized = file_type.lower()
    if normalized == 'pkl':
        return load_pkl(inp)
    if normalized == 'sdf':
        return load_sdf(inp)
    print('Wrong file type: ' + file_type)
    return None

### Read input data from an SDF or pickle file:
`reference_mols = load_data(path_to_reference_data)`<br/>
`generated_mols = load_data(path_to_generated_data)`<br/>
<br/>
### Split data by SMILES:
`splited_reference_dict = split_dataset_by_smiles(reference_mols)`<br/>
`splited_generated_dict = split_dataset_by_smiles(generated_mols)`<br/>
<br/>


# ICRMSD

In [19]:
def _compute_rmsd(conf_a: Chem.Mol, conf_b: Chem.Mol) -> float:
    """
    Computes the heavy-atom RMSD between two 3D conformers after aligning
    conf_a onto conf_b with RDKit's AlignMol.

    Hydrogens are removed first, so only heavy atoms contribute. Atoms are
    paired strictly by index (atom i of conf_a with atom i of conf_b); both
    conformers are therefore assumed to share the same atom ordering, and
    zip() silently stops at the shorter molecule if atom counts differ.
    Args:
        conf_a: conformer of molecule (moved during alignment)
        conf_b: another conformer of molecule (alignment reference)
    Returns:
        rmsd for two input conformers, or NaN if AlignMol raised a
        RuntimeError (NaN rather than 0 so a failure is never mistaken
        for a perfect match)
    """

    # RemoveHs returns new molecules, so AlignMol's in-place coordinate
    # update does not touch the caller's objects.
    conf_a = Chem.RemoveHs(conf_a)
    conf_b = Chem.RemoveHs(conf_b)
    atom_map = list(zip(range(conf_a.GetNumAtoms()),
                        range(conf_b.GetNumAtoms())))
    try:
        return AllChem.AlignMol(conf_a, conf_b, atomMap=atom_map)
    except RuntimeError:
        print(f"RMSD AlignMol crashed on {Chem.MolToSmiles(conf_a)}")
        return float('nan')


def _compute_distances(conf: Chem.Mol, confs: List[Chem.Mol])\
        -> np.ndarray:
    """
    Measures the RMSD from one conformer to each conformer in a list.
    Args:
        conf: conformer of molecule used as the probe in every comparison
        confs: conformers of the same molecule to compare against
    Returns:
        1-D array whose k-th entry is _compute_rmsd(conf, confs[k])
    """

    rmsd_values = [_compute_rmsd(conf, other) for other in confs]
    return np.array(rmsd_values)


def get_pairwise_rmsd_statistics(all_mols: List[List[Chem.Mol]])\
        -> Dict[str, float]:
    """
    Returns statistics of intra-molecule pairwise RMSD values (ICRMSD)
    for all molecules.

    For each molecule the RMSD is computed once per unordered pair of
    distinct conformers. Self-pairs (a conformer against itself, RMSD 0)
    are excluded because they would bias the statistics toward 0, and
    NaN values (failed alignments) are ignored.
    Args:
        all_mols: List of lists with conformers for each molecule
    Returns:
        dict with mean, std (population, ddof=0) and median values;
        all three are NaN if no conformer pair could be evaluated
    """

    all_pairwise = []
    for list_of_confs in tqdm(all_mols):
        for i, conf in enumerate(list_of_confs):
            # Compare only with later conformers: each unordered pair is
            # measured exactly once and never against itself.
            all_pairwise.append(
                _compute_distances(conf, list_of_confs[i + 1:]))

    values = np.concatenate(all_pairwise) if all_pairwise else np.empty(0)
    values = values[~np.isnan(values)]
    if values.size == 0:
        nan = float('nan')
        return {'mean': nan, 'std': nan, 'median': nan}
    return {'mean': np.mean(values),
            'std': np.std(values),
            'median': np.median(values)}

To compute ICRMSD for a given dict of molecule conformations, first transform the dict into a list of lists:<br/>
`generated_list_of_lists = dict_to_list_of_lists(splited_generated_dict)`<br/>
Then run<br/>
`stats = get_pairwise_rmsd_statistics(generated_list_of_lists)`

# RED

In [20]:
def get_energy(mol: Chem.Mol) -> Optional[float]:
    """
    Returns the MMFF94s energy value for a given molecule
    Args:
        mol: molecule (with 3D coordinates) for which the energy is computed
    Returns:
        energy value for the given molecule, or None if the energy
        can't be computed
    """

    try:
        prop = AllChem.MMFFGetMoleculeProperties(mol, mmffVariant='MMFF94s')
        if prop is None:
            # RDKit returns None when MMFF typing fails for this molecule.
            return None
        ff = AllChem.MMFFGetMoleculeForceField(mol, prop)
        if ff is None:
            return None
        return ff.CalcEnergy()
    except Exception:
        # Best-effort: any RDKit failure simply means "no energy available".
        # (Unlike a bare except, this does not swallow KeyboardInterrupt.)
        return None


def _red_metric(reference_conf_list: List[Chem.Mol],
                generated_conf_list: List[Chem.Mol]) -> float:
    """
    Returns the RED metric for given reference and generated conformers
    of one molecule:

        RED = (MED(E_generated) - MED(E_reference)) / N_heavy

    where E are MMFF94s energies of the conformers after adding
    hydrogens (with coordinates), MED is the median, and N_heavy is the
    number of heavy atoms of the first reference conformer. Conformers
    whose energy cannot be computed (None) or is NaN are ignored.
    Args:
        reference_conf_list: list of reference conformers
        generated_conf_list: list of generated conformers
    Returns:
        red metric value, or NaN if either side has no usable energy
    """

    if not reference_conf_list or not generated_conf_list:
        return float('nan')

    num_heavy_atoms = reference_conf_list[0].GetNumHeavyAtoms()
    reference_energies = [get_energy(Chem.AddHs(mol, addCoords=True))
                          for mol in reference_conf_list]
    generated_energies = [get_energy(Chem.AddHs(mol, addCoords=True))
                          for mol in generated_conf_list]
    reference_energies = [
        e for e in reference_energies if e is not None and not np.isnan(e)]
    generated_energies = [
        e for e in generated_energies if e is not None and not np.isnan(e)]
    if not reference_energies or not generated_energies:
        # Median of an empty list is undefined; return NaN explicitly
        # instead of relying on a NumPy RuntimeWarning.
        return float('nan')
    return (np.median(generated_energies) -
            np.median(reference_energies)) / num_heavy_atoms


def get_red_metric_statistics(reference_mols_dict, generated_mols_dict)\
        -> Dict[str, float]:
    """
    Returns RED metric value statistics for given molecules.

    RED is computed per SMILES for every SMILES present in both
    dictionaries; NaN values (no usable energies) are dropped before the
    statistics are taken.
    Args:
        reference_mols_dict: Dictionary with smiles as keys
            and list of reference conformers for this smiles as values
        generated_mols_dict: Dictionary with smiles as keys
            and list of generated conformers for this smiles as values
    Returns:
        dictionary with min, max, mean, std and median of RED values;
        every entry is NaN if no valid RED value was obtained
    """

    intersect_smiles = set(generated_mols_dict.keys()).intersection(set(reference_mols_dict.keys()))
    print(f'generated data contains {len(generated_mols_dict)} smiles '
          f'reference data contains {len(reference_mols_dict)} smiles '
          f'intersection is {len(intersect_smiles)} smiles ')
    all_red_metrics = [
        _red_metric(reference_mols_dict[smiles], generated_mols_dict[smiles])
        for smiles in tqdm(list(intersect_smiles))
    ]
    all_red_metrics_filtred = [
        red for red in all_red_metrics if not np.isnan(red)]
    if not all_red_metrics_filtred:
        # np.min/np.max raise ValueError on an empty sequence.
        nan = float('nan')
        return {key: nan for key in ('min', 'max', 'mean', 'std', 'median')}
    stats = {'min': np.min(all_red_metrics_filtred),
             'max': np.max(all_red_metrics_filtred),
             'mean': np.mean(all_red_metrics_filtred),
             'std': np.std(all_red_metrics_filtred),
             'median': np.median(all_red_metrics_filtred)}
    return stats

To compute RED for the given `splited_reference_dict` and `splited_generated_dict` (only SMILES present in both are evaluated), run<br/>
`stats = get_red_metric_statistics(splited_reference_dict, splited_generated_dict)`

# COV MAT

In [21]:
def _best_match(ref_conf: Chem.Mol, probes: List[Chem.Mol])\
        -> Tuple[Optional[Chem.Mol], float]:
    """
    Finds the probe conformer with the lowest RMSD to a reference conformer
    Args:
        ref_conf: reference conformer
        probes: candidate (generated) conformers of the same molecule
    Returns:
        (best probe or None, its RMSD); the RMSD is np.inf when no probe
        produced a comparable RMSD
    """

    best_rmsd = np.inf
    best_conf = None
    for probe in probes:
        try:
            rmsd = _compute_rmsd(probe, ref_conf)
        except KekulizeException:
            # Kekulization failed for this pair; skip it.
            continue
        # Strict "<" also skips NaN results, which never compare smaller.
        if rmsd < best_rmsd:
            best_rmsd = rmsd
            best_conf = probe
    return best_conf, best_rmsd


def calc_mat_cov(probe_mols: Dict[str, List[Chem.Mol]],
                 ref_mols: Dict[str, List[Chem.Mol]],
                 threshold: float = 1.25)\
        -> Tuple[Dict[str, float], Dict[str, dict]]:
    """
    Computes COV and MAT metric statistics for given generated (probe) and
    reference conformers of each molecule.

    For every reference conformer, the generated conformer with the lowest
    RMSD is found. Per SMILES:
        MAT = mean of these best RMSDs over the reference conformers
        COV = fraction of reference conformers whose best RMSD < threshold
    Only SMILES present in both dictionaries are evaluated; SMILES with an
    empty reference list are skipped.
    Args:
        probe_mols: Dictionary with smiles as keys
            and list of generated conformers for this smiles as values
        ref_mols: Dictionary with smiles as keys
            and list of reference conformers for this smiles as values
        threshold: RMSD threshold for counting a reference conformer as
            covered (COV metric)
    Returns:
        tuple (tot_stats, smiles_stats):
            tot_stats: dict with 'mean_cov', 'median_cov', 'mean_mat',
                'median_mat' over SMILES. SMILES whose MAT is infinite
                (no probe could be aligned) are excluded from the MAT
                statistics. Values are NaN when there is nothing to average.
            smiles_stats: per-SMILES dict with 'confs' (list of tuples
                (reference conf, best generated conf or None, best rmsd,
                covered flag)), 'cov' and 'mat'
    """

    intersect_smiles = set(probe_mols.keys()).intersection(set(ref_mols.keys()))
    print(f'generated data contains {len(probe_mols)} smiles '
          f'reference data contains {len(ref_mols)} smiles '
          f'intersection is {len(intersect_smiles)} smiles ')
    tot_cov = []
    tot_mat = []
    smiles_stats = dict()
    print('start main cycle')
    for smiles in tqdm(intersect_smiles):
        ref_confs = ref_mols[smiles]
        if not ref_confs:
            # Nothing to cover; also avoids division by zero below.
            continue
        conf_records = []
        smiles_cov = 0
        smiles_mat = 0
        for conf in ref_confs:
            best_conf, best_rmsd = _best_match(conf, probe_mols[smiles])
            covered = best_rmsd < threshold
            smiles_mat += best_rmsd
            if covered:
                smiles_cov += 1
            # (reference conformation, best generated conformation, best rmsd, cov)
            conf_records.append((conf, best_conf, best_rmsd, covered))

        cov = smiles_cov / len(ref_confs)
        mat = smiles_mat / len(ref_confs)
        smiles_stats[smiles] = {'confs': conf_records, 'cov': cov, 'mat': mat}
        tot_cov.append(cov)
        tot_mat.append(mat)

    # MAT is inf for a SMILES where no probe could be aligned to some
    # reference conformer; such SMILES would dominate the average.
    tot_mat = [mat for mat in tot_mat if mat != np.inf]

    nan = float('nan')
    tot_stats = {
        'mean_cov': np.mean(tot_cov) if tot_cov else nan,
        'median_cov': np.median(tot_cov) if tot_cov else nan,
        'mean_mat': np.mean(tot_mat) if tot_mat else nan,
        'median_mat': np.median(tot_mat) if tot_mat else nan
    }
    return tot_stats, smiles_stats

To compute the COV and MAT statistics for the given `splited_reference_dict`, `splited_generated_dict`<br/>
and `threshold`, run (the function returns the aggregate statistics and the per-SMILES details)<br/>
`stats, per_smiles_stats = calc_mat_cov(splited_generated_dict, splited_reference_dict, threshold=threshold)`