# In [2]:
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski
from rdkit.Chem import QED
import numpy as np
from rdkit.Chem.rdMolTransforms import GetAngleDeg
from sascorer import compute_sa_score
from copy import deepcopy
from rdkit.Chem.rdMolTransforms import GetBondLength

def disable_rdkit_logging():
    """
    Silence RDKit's console log output.

    Raises the RDKit logger threshold to ERROR, then additionally disables the
    'rdApp.error' log channel.
    """
    import rdkit.RDLogger as rdlogger
    import rdkit.rdBase as rdbase

    rdlogger.logger().setLevel(rdlogger.ERROR)
    rdbase.DisableLog('rdApp.error')

def cal_lipinski(mol):
    """
    Count how many of five Lipinski-style rules ``mol`` satisfies.

    The rules are evaluated on a sanitized deep copy, so the caller's molecule
    is not modified:
      1. exact molecular weight < 500
      2. number of H-bond donors <= 5
      3. number of H-bond acceptors <= 10
      4. -2 <= Crippen logP <= 5
      5. number of rotatable bonds <= 10

    Returns the number of satisfied rules (0-5), as produced by ``np.sum``.
    """
    mol = deepcopy(mol)
    Chem.SanitizeMol(mol)
    # Bind logP as a plain statement and check both bounds with a chained
    # comparison. Written as ``(logp := MolLogP(mol) >= -2)``, the walrus would
    # bind the *boolean* (``:=`` is looser than ``>=``), making ``logp <= 5``
    # always True and silently dropping the upper bound.
    logp = Crippen.MolLogP(mol)
    rules = [
        Descriptors.ExactMolWt(mol) < 500,
        Lipinski.NumHDonors(mol) <= 5,
        Lipinski.NumHAcceptors(mol) <= 10,
        -2 <= logp <= 5,
        Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10,
    ]
    return np.sum([int(rule) for rule in rules])

def cal_qed(mol):
    """Return the QED drug-likeness score of ``mol`` as computed by ``QED.qed``."""
    qed_score = QED.qed(mol)
    return qed_score

def cal_sa(mol):
    """Return the synthetic-accessibility score of ``mol`` from ``sascorer.compute_sa_score``."""
    sa_score = compute_sa_score(mol)
    return sa_score

def cal_logp(mol):
    """Return the Crippen logP estimate of ``mol`` (``Crippen.MolLogP``)."""
    logp_value = Crippen.MolLogP(mol)
    return logp_value


# In [3]:
def find_i_ring(mol, i):
    """
    Return True if any ring reported by ``Chem.GetSymmSSSR(mol)`` has length ``i``.

    i: the ring size to look for (compared against ``len`` of each ring).
    """
    return any(len(ring) == i for ring in Chem.GetSymmSSSR(mol))

def get_bond_angle(mol, bond_smi='CCC'):
    """
    Return the bond angle (degrees) at every match of a three-atom chain pattern.

    bond_smi: SMILES of a three-atom chain in which atom 1 is bonded to atoms
        0 and 2, e.g. 'CCC'. Atom 1 of the pattern is the angle's vertex.

    Returns a list with one ``GetAngleDeg`` value per substructure match of
    ``bond_smi`` in ``mol``, measured on mol's default conformer. An empty
    list is returned when there are no matches (the conformer is not accessed).

    Raises ValueError if ``bond_smi`` cannot be parsed or is not such a chain.
    """
    pattern = Chem.MolFromSmiles(bond_smi)
    if pattern is None:
        raise ValueError(f"could not parse bond_smi {bond_smi!r}")
    # Validate the pattern once: a substructure match maps every pattern bond
    # onto a bond of ``mol``, so checking 0-1 and 1-2 here guarantees every
    # match tuple is an ordered (end, vertex, end) chain.
    if (pattern.GetNumAtoms() != 3
            or pattern.GetBondBetweenAtoms(0, 1) is None
            or pattern.GetBondBetweenAtoms(1, 2) is None):
        raise ValueError(
            f"bond_smi {bond_smi!r} must be a 3-atom chain bonded 0-1 and 1-2"
        )

    matches = mol.GetSubstructMatches(pattern)
    if not matches:
        return []
    conf = mol.GetConformer()
    return [GetAngleDeg(conf, *match) for match in matches]

def get_bond_symbol(bond):
    """
    Encode ``bond`` as '<begin symbol><bond type as int><end symbol>'.

    The middle part is ``int(bond.GetBondType())``, e.g. 'C1C' for a
    single C-C bond, 'C2O' for a double C=O bond, 'C12C' for an aromatic bond
    (RDKit: single 1, double 2, triple 3, aromatic 12).
    """
    begin_symbol = bond.GetBeginAtom().GetSymbol()
    end_symbol = bond.GetEndAtom().GetSymbol()
    return f"{begin_symbol}{int(bond.GetBondType())}{end_symbol}"

def get_triple_bonds(mol):
    """
    List every chain of three consecutive bonds ``[left, center, right]`` in ``mol``.

    Each bond of ``mol`` is taken in turn as ``center``; ``left`` ranges over
    the other bonds on the center's begin atom and ``right`` over the other
    bonds on its end atom. Each chain is emitted once, with ``left`` always on
    the begin-atom side. Output is ordered by center bond, then right bond,
    then left bond.
    """
    triplets = []
    for center_idx, center in enumerate(mol.GetBonds()):
        begin = mol.GetAtomWithIdx(center.GetBeginAtomIdx())
        end = mol.GetAtomWithIdx(center.GetEndAtomIdx())
        lefts = [b for b in begin.GetBonds() if b.GetIdx() != center_idx]
        if not lefts:
            # Terminal begin atom: no chain can extend through it.
            continue
        triplets.extend(
            [left, center, right]
            for right in end.GetBonds()
            if right.GetIdx() != center_idx
            for left in lefts
        )
    return triplets

def _shared_atom_idx(bond_a, bond_b):
    """Return the index of the one atom shared by two bonds; raise ValueError otherwise."""
    common = (
        {bond_a.GetBeginAtomIdx(), bond_a.GetEndAtomIdx()}
        & {bond_b.GetBeginAtomIdx(), bond_b.GetEndAtomIdx()}
    )
    if len(common) != 1:
        raise ValueError("bonds do not share exactly one atom")
    return next(iter(common))


def _other_atom_idx(bond, atom_idx):
    """Return the index of the atom at the opposite end of ``bond`` from ``atom_idx``."""
    begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    if atom_idx == begin:
        return end
    if atom_idx == end:
        return begin
    raise ValueError(f"atom {atom_idx} is not an endpoint of bond {bond.GetIdx()}")


def get_dihedral_angle(mol, bonds_ref_sym):
    """
    Find bond triplets matching ``bonds_ref_sym`` in mol and return their dihedral angles.

    bonds_ref_sym: '-'-joined ``get_bond_symbol`` strings of a bond triplet,
        e.g. 'C1C-C1C-C1C'. A triplet matches if its symbols equal this
        string read forwards or backwards; a backwards match is flipped so the
        bonds line up with ``bonds_ref_sym``.

    Returns a list of ``GetDihedralDeg`` values (degrees) for the atom chain
    i-j-k-l spanned by each matching triplet, measured on mol's default
    conformer. An empty list is returned when nothing matches (the conformer
    is not accessed).
    """
    quads = []
    for bonds in get_triple_bonds(mol):
        syms = [get_bond_symbol(b) for b in bonds]
        if '-'.join(syms[::-1]) == bonds_ref_sym:
            # Reverse match (also taken for palindromic patterns, as before).
            bonds = bonds[::-1]
        elif '-'.join(syms) != bonds_ref_sym:
            continue
        first, middle, last = bonds
        # Resolve the atom chain explicitly; the helpers raise instead of
        # silently reusing indices left over from a previous triplet.
        a1 = _shared_atom_idx(first, middle)
        a0 = _other_atom_idx(first, a1)
        a2 = _other_atom_idx(middle, a1)
        a3 = _other_atom_idx(last, a2)
        # NOTE(review): in a three-membered ring a0 and a3 are the same atom,
        # so the angle is degenerate; confirm whether such triplets should be
        # skipped.
        quads.append((a0, a1, a2, a3))

    if not quads:
        return []
    conf = mol.GetConformer()
    return [Chem.rdMolTransforms.GetDihedralDeg(conf, *quad) for quad in quads]

def get_bond_length(mol, atom1, atom2, bt):
    """
    Return the lengths of all bonds of type ``bt`` between elements ``atom1`` and ``atom2``.

    atom1, atom2: element symbols compared with ``Atom.GetSymbol()``, matched
        in either order (e.g. 'C', 'O').
    bt: value compared with ``bond.GetBondType()``, presumably an RDKit bond
        type such as ``Chem.rdchem.BondType.SINGLE``.

    Returns a list of ``GetBondLength`` values from mol's default conformer,
    in the order the bonds appear in ``mol``. A molecule with no bonds
    yields ``[]`` without accessing the conformer.
    """
    bonds = list(mol.GetBonds())
    if not bonds:
        return []
    # The conformer is loop-invariant: fetch it once instead of per bond.
    conf = mol.GetConformer()
    bond_lengths = []
    for bond in bonds:
        if bond.GetBondType() != bt:
            continue
        begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
        pair = (begin_atom.GetSymbol(), end_atom.GetSymbol())
        if pair == (atom1, atom2) or pair == (atom2, atom1):
            bond_lengths.append(
                GetBondLength(conf, begin_atom.GetIdx(), end_atom.GetIdx())
            )
    return bond_lengths