In [208]:
import fragmenter
import json
from openeye import oechem, oequacpac, oedepict, oegraphsim, oegrapheme
import matplotlib.pyplot as plt
import glob
import seaborn as sbn
import cmiles
import itertools
import numpy as np
import oenotebook as oenb

In [277]:
def mmd_x_xsqred(x, y):
    """
    Maximum mean discrepancy using the features (v, v**2).

    Comparing the first two moments lets the score distinguish distributions
    that differ in mean and/or in variance.
    see https://stats.stackexchange.com/questions/276497/maximum-mean-discrepancy-distance-distribution

    Parameters
    ----------
    x : sequence of numbers
        First sample (e.g. per-conformer Wiberg bond orders).
    y : sequence of numbers
        Second sample.

    Returns
    -------
    float
        sqrt((E[x] - E[y])**2 + (E[x**2] - E[y**2])**2). Note that this is the
        distance itself, not its square.
    """
    x_vals = np.asarray(x)
    y_vals = np.asarray(y)

    first_moment_gap = x_vals.mean() - y_vals.mean()
    second_moment_gap = np.square(x_vals).mean() - np.square(y_vals).mean()

    return np.sqrt(first_moment_gap**2 + second_moment_gap**2)

def get_bond(mol, bond_idx):
    """
    Return the bond between the two atoms with the given atom map indices.

    Parameters
    ----------
    mol : oechem.OEMolBase
        Atom-mapped molecule.
    bond_idx : tuple of (int, int)
        Map indices of the two atoms.

    Returns
    -------
    The bond object returned by ``mol.GetBond``.

    Raises
    ------
    ValueError
        If either map index matches no atom, or the two atoms are not bonded.
    """
    a1 = mol.GetAtom(oechem.OEHasMapIdx(bond_idx[0]))
    a2 = mol.GetAtom(oechem.OEHasMapIdx(bond_idx[1]))
    # GetAtom yields no atom when the map index is absent; report that clearly
    # instead of passing a missing atom on to GetBond.
    if not a1 or not a2:
        raise ValueError("map indices {} not found in molecule".format(tuple(bond_idx)))
    bond = mol.GetBond(a1, a2)
    if not bond:
        raise ValueError("({}) atoms are not connected".format(bond_idx))
    return bond


def rbg_to_int(rbg, alpha):
    """
    Convert an RGBA color with float components in [0, 1] to integer components
    for ``oechem.OEColor``.

    Parameters
    ----------
    rbg : sequence of float
        (r, g, b, a) color, e.g. one row returned by a matplotlib colormap.
        The input is NOT modified.
    alpha : int
        Scale for the alpha channel; the output alpha is ``int(a * alpha)``.

    Returns
    -------
    list of ints
        [r, g, b, a] with r, g, b rounded to 0-255.
    """
    # Read the components instead of writing back into ``rbg``: colormap rows
    # are views into the caller's array, so assigning rbg[-1] corrupted it.
    *channels, opacity = rbg
    colors = [int(round(c * 255)) for c in channels]
    colors.append(int(opacity * alpha))
    return colors

def visualize_bond_sensitivity(mols, bonds, scores, fname, wbos, rows, cols, height=600, width=600):
    """
    Render molecules into an OpenEye report with bonds colored by sensitivity score.

    Each molecule gets one report cell. Every bond listed for it gets a
    ``WibergBondOrder`` value and is highlighted with both the stick and the
    color highlight style. The coolwarm color is scaled from 0 to that
    molecule's own maximum score.

    Parameters
    ----------
    mols : list of oechem.OEMol
        Molecules to depict. Bonds are looked up by atom map index, so atoms
        must be mapped. Bonds are modified in place ('WibergBondOrder' data).
    bonds : list of list of tuple(int, int)
        Per molecule, the map-index pairs of the bonds to highlight.
    scores : list of list of float
        Per molecule, one score per entry of ``bonds[i]``.
    fname : str
        Output path passed to ``oedepict.OEWriteReport``.
    wbos : list of list of float
        Per molecule, one Wiberg bond order per entry of ``bonds[i]``.
    rows, cols : int
        Grid layout of the report.
    height, width : int, optional
        Unused; kept for backward compatibility (cell size comes from the report).

    Returns
    -------
    The value returned by ``oedepict.OEWriteReport``.
    """
    itf = oechem.OEInterface()
    ropts = oedepict.OEReportOptions(rows, cols)
    ropts.SetHeaderHeight(0.01)
    ropts.SetFooterHeight(0.01)
    ropts.SetCellGap(0.001)
    ropts.SetPageMargins(0.01)
    report = oedepict.OEReport(ropts)

    cellwidth, cellheight = report.GetCellWidth(), report.GetCellHeight()
    # Debug output of the cell size (previously printed the height twice).
    print(cellwidth, cellheight)
    opts = oedepict.OE2DMolDisplayOptions(cellwidth, cellheight, oedepict.OEScale_AutoScale)
    oedepict.OESetup2DMolDisplayOptions(opts, itf)
    opts.SetAromaticStyle(oedepict.OEAromaticStyle_Circle)
    opts.SetAtomColorStyle(oedepict.OEAtomColorStyle_WhiteMonochrome)

    pen = oedepict.OEPen(oechem.OEBlack, oechem.OEBlack, oedepict.OEFill_Off, 0.9)
    opts.SetDefaultBondPen(pen)
    # NOTE(review): setup is applied a second time after customizing opts;
    # confirm it does not reset the styles set above.
    oedepict.OESetup2DMolDisplayOptions(opts, itf)

    # Draw every cell with the smallest per-molecule scale so all molecules
    # share one magnification.
    minscale = float("inf")
    for m in mols:
        oedepict.OEPrepareDepiction(m, False, True)
        minscale = min(minscale, oedepict.OEGetMoleculeScale(m, opts))

    opts.SetScale(minscale)
    print(minscale)
    for i, mol in enumerate(mols):
        cell = report.NewCell()
        oedepict.OEPrepareDepiction(mol, False, True)
        atom_bond_sets = []
        wbo = wbos[i]
        for j, bond in enumerate(bonds[i]):
            bo = get_bond(mol, bond)
            # Presumably read back by LabelWibergBondOrder to label the bond.
            bo.SetData('WibergBondOrder', wbo[j])
            atom_bond_set = oechem.OEAtomBondSet()
            atom_bond_set.AddBond(bo)
            atom_bond_sets.append(atom_bond_set)

        opts.SetBondPropertyFunctor(fragmenter.chemi.LabelWibergBondOrder())
        opts.SetTitleLocation(oedepict.OETitleLocation_Hidden)

        disp = oedepict.OE2DMolDisplay(mol, opts)
        hstyle = oedepict.OEHighlightStyle_Stick
        hstyle_2 = oedepict.OEHighlightStyle_Color
        score = scores[i]
        # Colors are normalized per molecule. default=0 keeps max() from raising
        # when a molecule has no scored bonds.
        norm = plt.Normalize(0, max(score, default=0))
        colors = plt.cm.coolwarm(norm(score))
        colors_oe = [rbg_to_int(c, 200) for c in colors]

        for j, atom_bond_set in enumerate(atom_bond_sets):
            highlight = oechem.OEColor(*colors_oe[j])
            oedepict.OEAddHighlighting(disp, highlight, hstyle, atom_bond_set)
            oedepict.OEAddHighlighting(disp, highlight, hstyle_2, atom_bond_set)

        oedepict.OERenderMolecule(cell, disp)

    return oedepict.OEWriteReport(fname, report)

def visualize_bond_atom_sensitivity(mols, bonds, atoms, scores, fname, wbos, rows, cols, height=600, width=600):
    """
    Like ``visualize_bond_sensitivity``, but also marks selected atoms.

    Each molecule gets one report cell. Every listed bond gets a
    'WibergBondOrder' value and is highlighted (stick + color) with a coolwarm
    color scaled from 0 to that molecule's maximum score. Each listed atom is
    additionally highlighted with a cogwheel drawn in the color of a chosen bond.

    Parameters
    ----------
    mols : list of oechem.OEMol
        Atom-mapped molecules to depict. Bonds are modified in place
        ('WibergBondOrder' data).
    bonds : list of list of tuple(int, int)
        Per molecule, the map-index pairs of the bonds to highlight.
    atoms : list of list of tuple(int, int)
        Per molecule, (atom map index, position in ``bonds[i]``) pairs. The
        atom's cogwheel takes the color of the bond at that position.
    scores : list of list of float
        Per molecule, one score per entry of ``bonds[i]``.
    fname : str
        Output path passed to ``oedepict.OEWriteReport``.
    wbos : list of list of float
        Per molecule, one Wiberg bond order per entry of ``bonds[i]``.
    rows, cols : int
        Grid layout of the report.
    height, width : int, optional
        Unused; kept for backward compatibility (cell size comes from the report).

    Returns
    -------
    The value returned by ``oedepict.OEWriteReport``.

    Raises
    ------
    ValueError
        If an atom map index in ``atoms`` matches no atom.
    """
    itf = oechem.OEInterface()
    ropts = oedepict.OEReportOptions(rows, cols)
    ropts.SetHeaderHeight(0.01)
    ropts.SetFooterHeight(0.01)
    ropts.SetCellGap(0.001)
    ropts.SetPageMargins(0.01)
    report = oedepict.OEReport(ropts)

    cellwidth, cellheight = report.GetCellWidth(), report.GetCellHeight()
    # Debug output of the cell size (previously printed the height twice).
    print(cellwidth, cellheight)
    opts = oedepict.OE2DMolDisplayOptions(cellwidth, cellheight, oedepict.OEScale_AutoScale)
    oedepict.OESetup2DMolDisplayOptions(opts, itf)
    opts.SetAromaticStyle(oedepict.OEAromaticStyle_Circle)
    opts.SetAtomColorStyle(oedepict.OEAtomColorStyle_WhiteMonochrome)

    pen = oedepict.OEPen(oechem.OEBlack, oechem.OEBlack, oedepict.OEFill_Off, 0.9)
    opts.SetDefaultBondPen(pen)
    # NOTE(review): setup is applied a second time after customizing opts;
    # confirm it does not reset the styles set above.
    oedepict.OESetup2DMolDisplayOptions(opts, itf)

    # Draw every cell with the smallest per-molecule scale so all molecules
    # share one magnification.
    minscale = float("inf")
    for m in mols:
        oedepict.OEPrepareDepiction(m, False, True)
        minscale = min(minscale, oedepict.OEGetMoleculeScale(m, opts))

    opts.SetScale(minscale)
    print(minscale)
    for i, mol in enumerate(mols):
        cell = report.NewCell()
        oedepict.OEPrepareDepiction(mol, False, True)
        atom_bond_sets = []
        wbo = wbos[i]
        for j, bond in enumerate(bonds[i]):
            bo = get_bond(mol, bond)
            # Presumably read back by LabelWibergBondOrder to label the bond.
            bo.SetData('WibergBondOrder', wbo[j])
            atom_bond_set = oechem.OEAtomBondSet()
            atom_bond_set.AddBond(bo)
            atom_bond_sets.append(atom_bond_set)

        opts.SetBondPropertyFunctor(fragmenter.chemi.LabelWibergBondOrder())
        opts.SetTitleLocation(oedepict.OETitleLocation_Hidden)

        disp = oedepict.OE2DMolDisplay(mol, opts)
        hstyle = oedepict.OEHighlightStyle_Stick
        hstyle_2 = oedepict.OEHighlightStyle_Color
        score = scores[i]
        # Colors are normalized per molecule. default=0 keeps max() from raising
        # when a molecule has no scored bonds.
        norm = plt.Normalize(0, max(score, default=0))
        colors = plt.cm.coolwarm(norm(score))
        colors_oe = [rbg_to_int(c, 200) for c in colors]

        for j, atom_bond_set in enumerate(atom_bond_sets):
            bond_color = oechem.OEColor(*colors_oe[j])
            oedepict.OEAddHighlighting(disp, bond_color, hstyle, atom_bond_set)
            oedepict.OEAddHighlighting(disp, bond_color, hstyle_2, atom_bond_set)

        # One cogwheel style object is recolored for each marked atom.
        cogwheel = oedepict.OEHighlightByCogwheel(oechem.OEDarkPurple)
        cogwheel.SetBallRadiusScale(5.0)
        for entry in atoms[i]:
            map_idx, bond_pos = entry[0], entry[-1]
            cogwheel.SetColor(oechem.OEColor(*colors_oe[bond_pos]))
            target_atom = mol.GetAtom(oechem.OEHasMapIdx(map_idx))
            if not target_atom:
                raise ValueError("atom map index {} not found in molecule {}".format(map_idx, i))
            atom_set = oechem.OEAtomBondSet()
            atom_set.AddAtom(target_atom)
            oedepict.OEAddHighlighting(disp, cogwheel, atom_set)
        oedepict.OERenderMolecule(cell, disp)

    return oedepict.OEWriteReport(fname, report)

def get_color_gradient(mol, tagname, ncolor, pcolor):
    """
    Build a linear color gradient over the values stored under ``tagname``.

    The gradient has white at 0.0 and ``ncolor`` at the smallest tagged value.
    It adds ``pcolor`` at the largest tagged value when that differs from the
    smallest. Values come from ``get_min_max_atom_property``, which scans the
    bonds of ``mol``. If no bond carries the tag, only the white stop is
    returned.

    :type mol: oechem.OEMolBase
    :type tagname: string
    :type ncolor: oechem.OEColor
    :type pcolor: oechem.OEColor
    :rtype: oechem.OELinearColorGradient
    """
    minvalue, maxvalue = get_min_max_atom_property(mol, tagname)
    colorg = oechem.OELinearColorGradient(oechem.OEColorStop(0.0, oechem.OEWhite))
    # No tagged bonds: the helper returns an empty range (min > max), so there
    # is nothing to add (previously an infinite stop was added).
    if minvalue > maxvalue:
        return colorg
    # NOTE(review): if minvalue == 0.0 this stop shares its position with the
    # white stop above; confirm which color OELinearColorGradient keeps.
    colorg.AddStop(oechem.OEColorStop(minvalue, ncolor))
    if maxvalue > minvalue:
        colorg.AddStop(oechem.OEColorStop(maxvalue, pcolor))

    return colorg


def depict_atom_property_atomglyph(disp, tagname, colorg):
    """
    Draw a colored circle glyph around every atom that carries ``tagname`` data.

    Each circle's color is ``colorg.GetColorAt`` evaluated at the atom's
    double value for that tag. Atoms without the tag are left untouched.

    :type disp: oedepict.OE2DMolDisplay
    :type tagname: string
    :type colorg: oechem.OELinearColorGradient
    """
    tag = oechem.OEGetTag(tagname)

    for atom in disp.GetMolecule().GetAtoms():
        if not atom.HasData(tag):
            continue
        atom_color = colorg.GetColorAt(atom.GetDoubleData(tag))
        circle_pen = oedepict.OEPen(atom_color, atom_color, oedepict.OEFill_Off, 3.0)
        circle = oegrapheme.OEAtomGlyphCircle(circle_pen, oegrapheme.OECircleStyle_Default, 1.2)
        oegrapheme.OEAddGlyph(disp, circle, oechem.OEHasAtomIdx(atom.GetIdx()))
            
def get_min_max_atom_property(mol, tagname):
    """
    Return the smallest and largest value stored under ``tagname`` on the bonds of ``mol``.

    Despite the name, only bonds are inspected; bonds without the tag are
    skipped. If no bond carries the tag, the result is ``(inf, -inf)``, i.e.
    minimum > maximum.

    :type mol: oechem.OEMolBase
    :type tagname: string
    :rtype: (float, float)
    """
    minvalue = float("inf")
    # Start from -inf, not 0: with 0 an all-negative property reported 0 as its maximum.
    maxvalue = float("-inf")

    tag = oechem.OEGetTag(tagname)

    for bond in mol.GetBonds():
        if bond.HasData(tag):
            val = bond.GetData(tag)
            minvalue = min(minvalue, val)
            maxvalue = max(maxvalue, val)

    return minvalue, maxvalue

            

def depict_bond_property_atomglyph(disp, tagname, colorg):
    """
    Highlight every bond that carries ``tagname`` data, colored by its value.

    Despite "atomglyph" in the name, no glyphs are drawn. Each tagged bond is
    highlighted twice, once with ``OEHighlightStyle_Stick`` and once with
    ``OEHighlightStyle_Color``. Both use ``colorg.GetColorAt`` evaluated at the
    bond's double value. Bonds without the tag are left untouched.

    :type disp: oedepict.OE2DMolDisplay
    :type tagname: string
    :type colorg: oechem.OELinearColorGradient
    """
    tag = oechem.OEGetTag(tagname)
    mol = disp.GetMolecule()
    hstyle = oedepict.OEHighlightStyle_Stick
    hstyle_2 = oedepict.OEHighlightStyle_Color
    for bond in mol.GetBonds():
        if bond.HasData(tag):
            atom_bond_set = oechem.OEAtomBondSet()
            atom_bond_set.AddBond(bond)
            value = bond.GetDoubleData(tag)
            highlight = colorg.GetColorAt(value)
            oedepict.OEAddHighlighting(disp, highlight, hstyle, atom_bond_set)
            oedepict.OEAddHighlighting(disp, highlight, hstyle_2, atom_bond_set)

In [159]:
mol_names = glob.glob('../../combinatorial_fragmentation/rank_fragments/selected/*')

# Molecules shown in the first figure panel.
# Other names previously listed here (commented out): 'Tedizolid_phosphate_0',
# 'Dacomitinib_0', 'Furosemide_0', 'Nitazoxanide_0', 'Nizatidine_0',
# 'Sulfoxone_0', 'Safinamide_0', 'Proguanil_7'
panel_1_mols = [
    'Acemetacin_0',
    'Ademetionine_0',
    'Almitrine_1',
    'Amlodipine_0',
    'Bosutinib_0',
    'Ceftazidime_0',
    'Eltrombopag_1',
    'Sulfinpyrazone_0',
    'Fostamatinib_0',
]

In [197]:
# Molecules shown in the second figure panel.
panel_2_mols = [
    'Dacomitinib_0',
    'Tedizolid_phosphate_0',
    'Nizatidine_0',
]

In [166]:
# Load the per-bond WBO results of every panel-1 molecule. Each bond is scored
# by the largest mmd_x_xsqred distance between the parent's per-conformer WBOs
# and those of every entry stored for that bond (the parent itself gives 0).
all_bonds = []
all_scores = []
all_scores_flat = []
all_elf10_wbos = []
lengths = []  # running total of bonds after each molecule
all_smiles = []
names = []
for name in panel_1_mols:
    with open('../../combinatorial_fragmentation/rank_fragments/selected/{}/{}_oe_wbo_with_score.json'.format(name, name), 'r') as f:
        results = json.load(f)

    bonds = []
    scores = []
    elf10_wbos = []
    # Reset per molecule so a molecule without bonds cannot silently reuse the
    # previous molecule's SMILES.
    parent_smiles = None
    for bond in results:
        bond_des = fragmenter.utils.deserialize_bond(bond)
        # Exactly one parent entry is required. Otherwise parent_wbos would be stale
        # (none) or elf10_wbos would get out of step with bonds (several).
        parent_keys = [key for key in results[bond] if 'parent' in key]
        if len(parent_keys) != 1:
            raise ValueError("expected exactly one parent entry for bond {} of {}, found {}".format(
                bond, name, len(parent_keys)))
        parent = results[bond][parent_keys[0]]
        parent_wbos = parent['individual_confs']
        elf10_wbos.append(parent['elf_estimate'])
        parent_smiles = parent['map_to_parent']
        score = 0
        for key in results[bond]:
            score = max(score, mmd_x_xsqred(parent_wbos, results[bond][key]['individual_confs']))
        bonds.append(bond_des)
        scores.append(score)
    if parent_smiles is None:
        raise ValueError("no bonds found in results for {}".format(name))
    all_bonds.append(bonds)
    all_scores.append(scores)
    all_scores_flat.extend(scores)
    all_smiles.append(parent_smiles)
    all_elf10_wbos.append(elf10_wbos)
    lengths.append(len(bonds) + (lengths[-1] if lengths else 0))
    names.append(name)
    

In [167]:
# Largest bond score across the loaded molecules; the next cell uses the same
# value as the upper bound of the shared color scale.
max(all_scores_flat)

0.39362985181674176

In [162]:
# Map every score onto one shared coolwarm scale running from 0 to the global
# maximum, then convert the RGBA floats to OpenEye integer colors (alpha scaled to 255).
norm = plt.Normalize(vmin=0, vmax=max(all_scores_flat))
colors = plt.cm.coolwarm(norm(all_scores_flat))
colors_oe = [rbg_to_int(rgba, alpha=255) for rgba in colors]

In [163]:
# Build one OpenEye molecule per mapped parent SMILES, titled with its name.
mols = []
for i, smiles in enumerate(all_smiles):
    mol = oechem.OEMol()
    # OESmilesToMol reports success through its return value; stop on a parse
    # failure rather than depicting an empty molecule.
    if not oechem.OESmilesToMol(mol, smiles):
        raise ValueError("could not parse SMILES for {}: {}".format(names[i], smiles))
    mol.SetTitle(names[i])
    mols.append(mol)

In [164]:
# Write the panel-1 report (4 x 3 grid) to test_2.pdf.
visualize_bond_sensitivity(
    mols,
    fname='test_2.pdf',
    bonds=all_bonds,
    scores=all_scores,
    wbos=all_elf10_wbos,
    rows=4,
    cols=3,
)

197.98875 197.98875
14.07614752506787


True

In [235]:
# Position of bond (28, 17) in the first molecule's bond list. The panel-2
# cell performs the same lookup to build `atoms`. The result depends on which
# loading cell (panel 1 or panel 2) last assigned `all_bonds`.
all_bonds[0].index((28, 17))

6

In [278]:
panel_2_mols = ['Dacomitinib_0', 'Tedizolid_phosphate_0', 'Nizatidine_0']

# Load and score the panel-2 molecules. Each bond's score is the largest
# mmd_x_xsqred distance between the parent's per-conformer WBOs and those of
# every entry stored for that bond.
all_bonds = []
all_scores = []
all_scores_flat = []
all_elf10_wbos = []
lengths = []  # running total of bonds after each molecule
all_smiles = []
names = []
for name in panel_2_mols:
    with open('../../combinatorial_fragmentation/rank_fragments/selected/{}/{}_oe_wbo_with_score.json'.format(name, name), 'r') as f:
        results = json.load(f)

    bonds = []
    scores = []
    elf10_wbos = []
    # Reset per molecule so a molecule without bonds cannot silently reuse the
    # previous molecule's SMILES.
    parent_smiles = None
    for bond in results:
        bond_des = fragmenter.utils.deserialize_bond(bond)
        # Exactly one parent entry is required. Otherwise parent_wbos would be stale
        # (none) or elf10_wbos would get out of step with bonds (several).
        parent_keys = [key for key in results[bond] if 'parent' in key]
        if len(parent_keys) != 1:
            raise ValueError("expected exactly one parent entry for bond {} of {}, found {}".format(
                bond, name, len(parent_keys)))
        parent = results[bond][parent_keys[0]]
        parent_wbos = parent['individual_confs']
        elf10_wbos.append(parent['elf_estimate'])
        parent_smiles = parent['map_to_parent']
        score = 0
        for key in results[bond]:
            score = max(score, mmd_x_xsqred(parent_wbos, results[bond][key]['individual_confs']))
        bonds.append(bond_des)
        scores.append(score)
    if parent_smiles is None:
        raise ValueError("no bonds found in results for {}".format(name))
    all_bonds.append(bonds)
    all_scores.append(scores)
    all_scores_flat.extend(scores)
    all_smiles.append(parent_smiles)
    all_elf10_wbos.append(elf10_wbos)
    lengths.append(len(bonds) + (lengths[-1] if lengths else 0))
    names.append(name)
norm = plt.Normalize(0, max(all_scores_flat))
colors = plt.cm.coolwarm(norm(all_scores_flat))
colors_oe = [rbg_to_int(c, 255) for c in colors]

# Atoms to mark, per molecule: (atom map index, position in all_bonds[i]).
# Each atom is drawn in the color of that bond; only the first molecule has marks.
atoms = [[(29, all_bonds[0].index((28, 17))), (30, all_bonds[0].index((27, 14)))],
         [], []]

mols = []
for i, smiles in enumerate(all_smiles):
    mol = oechem.OEMol()
    # Stop on a SMILES parse failure rather than depicting an empty molecule.
    if not oechem.OESmilesToMol(mol, smiles):
        raise ValueError("could not parse SMILES for {}: {}".format(names[i], smiles))
    mol.SetTitle(names[i])
    # Attach each bond's score to the bond itself under the 'score' tag.
    for j, bond in enumerate(all_bonds[i]):
        bo = get_bond(mol, bond)
        bo.SetData('score', all_scores[i][j])
    mols.append(mol)

visualize_bond_atom_sensitivity(mols, bonds=all_bonds, atoms=atoms, scores=all_scores, wbos=all_elf10_wbos,
                                rows=4, cols=3, fname='test_3.pdf')

197.98875 197.98875
14.270786498238088


True

[[(10, 28),
  (11, 31),
  (16, 24),
  (17, 15),
  (24, 29),
  (27, 14),
  (28, 17),
  (9, 27)],
 [(11, 12), (15, 17), (17, 29), (29, 31), (8, 7), (9, 23)],
 [(10, 16),
  (11, 15),
  (12, 11),
  (15, 5),
  (2, 9),
  (21, 12),
  (3, 10),
  (4, 17),
  (5, 14),
  (9, 21)]]

In [209]:
# Gradient over the 'score' values on the first molecule's bonds:
# ncolor (blue) at the lowest score, pcolor (red) at the highest.
gradients = get_color_gradient(
    mol=mols[0],
    tagname='score',
    ncolor=oechem.OEColor('blue'),
    pcolor=oechem.OEColor('red'),
)



In [195]:
# Inspect the gradient object returned by get_color_gradient. This previously
# referenced `oechem.OEColorg`, which is not the gradient built above and
# fails on a fresh kernel.
gradients

<oechem.OELinearColorGradient; proxy of <Swig Object of type 'OESystem::OELinearColorGradient *' at 0x1a2531e060> >

In [222]:
# Per-bond scores of the first loaded molecule, in the same order as all_bonds[0].
all_scores[0]

[0.07715367450985172,
 0.06476712244986732,
 0.01762664296729302,
 0.07833043114523443,
 0.03099894778381171,
 0.10631671755389016,
 0.19568849600550278,
 0.045782717681051674]