In [1]:
import pybel
import openbabel
import numpy as np
from mayavi import mlab
import tqdm

import mol
import grid
import trainer

import os
import gzip

Using TensorFlow backend.


In [2]:
# Enable IPython's autoreload extension. Mode 2 re-imports modules before each
# cell runs, so edits to the imported .py files are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [64]:
# Load the training data through the project-local `trainer` module.
# NOTE(review): the structure of the returned object is not visible here — confirm in trainer.
d = trainer.load_training_data()

100%|██████████| 102/102 [00:28<00:00,  4.59it/s]


In [65]:
# Number of entries in the loaded training data `d`.
len(d)

85

In [66]:
# Pull the first batch from the grid generator.
# The builtin next() works on Python 2.6+ and Python 3, whereas the generator's
# .next() method only exists on Python 2.
# NOTE(review): the positional arguments (2, 20, 1) are presumably batch size,
# grid size and resolution — confirm against trainer.grid_generator.
a = next(trainer.grid_generator(d, 2, 20, 1))

In [58]:
# First sample of the 'target_in' entry from the first element of batch `a`.
g = a[0]['target_in'][0]

In [71]:
# Shape of the 'fragment_out' array held in the second element of batch `a`.
a[1]['fragment_out'].shape

(2, 20, 20, 20)

In [94]:
# Build one fragment file per docked target: keep the active ligands that
# contain a carboxylic acid group and store one Mol per SMARTS match.
targets = [x for x in os.listdir('../data/docked_dude/') if x[0] != '.']
cooh = pybel.Smarts('[CX3](=O)[OX2H1]')

# temporary file to store uncompressed sdf data
tmp_file = '../data/tmp.sdf'

# make sure the output directory exists before anything is written into it
mol_dir = '../data/mol/'
if not os.path.isdir(mol_dir):
    os.makedirs(mol_dir)

try:
    for t in targets:
        # find actives file
        a_file = os.path.join('../data/docked_dude/', t, 'actives_final_docked_vina.sdf.gz')

        # decompress: both ends are opened in binary mode so the bytes are copied
        # verbatim, and the data is streamed in 1 MiB chunks instead of loading
        # the whole archive into memory at once
        with gzip.open(a_file, 'rb') as f, open(tmp_file, 'wb') as w:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                w.write(chunk)

        # read active molecules
        actives = list(pybel.readfile('sdf', tmp_file))

        mols = []

        # filter by functional group; an active with several matches yields
        # several Mol objects. The loop variable is called `match` (not `m`)
        # so that it cannot clobber the notebook-level model object `m`.
        for active in actives:
            for match in cooh.findall(active):
                mols.append(mol.Mol.from_pybel(active, match))

        print('%s : found %d matches' % (t, len(mols)))

        # save molecules
        m_file = os.path.join(mol_dir, t + '.dat')
        mol.Mol.writefile(m_file, mols)
finally:
    # the uncompressed copy is only scratch space; do not leave it behind,
    # even when a target fails half-way through
    if os.path.exists(tmp_file):
        os.remove(tmp_file)

aa2ar : found 36 matches
abl1 : found 27 matches
ace : found 8469 matches
aces : found 18 matches
ada : found 0 matches
ada17 : found 63 matches
adrb1 : found 405 matches
adrb2 : found 135 matches
akt1 : found 15 matches
akt2 : found 9 matches
aldr : found 1030 matches
ampc : found 648 matches
andr : found 32 matches
aofb : found 9 matches
bace1 : found 45 matches
braf : found 9 matches
cah2 : found 1098 matches
casp3 : found 1863 matches
cdk2 : found 0 matches
comt : found 45 matches
cp2c9 : found 204 matches
cp3a4 : found 81 matches
csf1r : found 18 matches
cxcr4 : found 0 matches
def : found 0 matches
dhi1 : found 145 matches
dpp4 : found 333 matches
drd3 : found 0 matches
dyr : found 3519 matches
egfr : found 45 matches
esr1 : found 41 matches
esr2 : found 79 matches
fa10 : found 392 matches
fa7 : found 765 matches
fabp4 : found 492 matches
fak1 : found 36 matches
fgfr1 : found 36 matches
fkb1a : found 36 matches
fnta : found 2492 matches
fpps : found 0 matches
gcr : found 24 match

In [54]:
# Collect the molecules stored in every non-hidden file under ../data/mol/
# into a single flat list.
mols = []

targets = [x for x in os.listdir('../data/mol/') if not x.startswith('.')]

for t in targets:
    mols.extend(mol.Mol.readfile(os.path.join('../data/mol/', t)))

In [257]:
# atom types to consider
# (atomic numbers: C, N, O, F, P, S, Cl, Br, I)
ATOM_TYPES = [
    6, 7, 8, 9, 15, 16, 17, 35, 53
]

# mapping
# atomic number -> van der waals radius
# NOTE(review): units are presumably angstroms, and the values look like the
# AutoDock Vina atom radii rather than Bondi radii — confirm the source.
VDW_RADIUS = {
    6: 1.9,   # C
    7: 1.8,   # N
    8: 1.7,   # O
    9: 1.5,   # F
    15: 2.1,  # P
    16: 2.0,  # S
    17: 1.8,  # Cl
    35: 2.0,  # Br
    53: 2.2   # I
}

In [258]:
import grid

In [266]:
# Build a grid for one molecule (index 20000 of `mols`), centred on that
# molecule's own center().
# NOTE(review): the trailing arguments (24, 1) are presumably grid size and
# resolution — confirm against grid.generate_grid_cpu.
g = grid.generate_grid_cpu(mols[20000].atoms, ATOM_TYPES, VDW_RADIUS, mols[20000].center(), 24, 1)

In [268]:
# Read the first structure from the aa2ar receptor PDB file.
# next() takes the first record straight from the reader instead of parsing
# every record into a list just to index [0]; it also matches how the other
# receptor-loading cells take the first record.
aa2ar_pdb = next(pybel.readfile('pdb', '../data/dude/aa2ar/receptor.pdb'))

In [269]:
# Convert the pybel receptor molecule into the project's Mol representation.
aa2ar_mol = mol.Mol.from_pybel(aa2ar_pdb)

In [270]:
# Build a grid for the aa2ar receptor, centred on the receptor's own center().
# This overwrites the notebook-level `g`.
# NOTE(review): the trailing arguments (24, 1) are presumably grid size and
# resolution — confirm against grid.generate_grid_cpu.
g = grid.generate_grid_cpu(aa2ar_mol.atoms, ATOM_TYPES, VDW_RADIUS, aa2ar_mol.center(), 24, 1)

In [None]:
# Show the first structure of the aa2ar receptor PDB file.
# Builtin next() instead of the Python-2-only .next() method.
next(pybel.readfile('pdb', '../data/dude/aa2ar/receptor.pdb'))

In [53]:
# Volume rendering of grid `g` with mayavi: two iso-surfaces plus a z-slice.
# Sum over axis 3 to obtain one scalar per voxel.
# NOTE(review): axis 3 is presumably the atom-type channel — confirm.
s = np.sum(g, axis=3)

# Value range of the field, computed once. np.ptp() is used because the
# ndarray.ptp() method no longer exists in NumPy 2.0.
s_range = np.ptp(s)

src = mlab.pipeline.scalar_field(s)
# faint outer surface at 10% above the minimum
mlab.pipeline.iso_surface(src, contours=[s.min() + 0.1 * s_range], opacity=0.1)
# opaque inner surface at 10% below the maximum
mlab.pipeline.iso_surface(src, contours=[s.max() - 0.1 * s_range])
mlab.pipeline.image_plane_widget(src,
                                 plane_orientation='z_axes',
                                 slice_index=10)
mlab.show()

In [271]:
# Draw every non-zero entry of the 4-D grid `g` as a sphere: position is the
# voxel index, size is the stored value, colour is the last-axis index.
#
# np.nonzero returns the indices in row-major order, which is the same order
# the nested Python loops over (x, y, z) and the last axis would visit them,
# but without per-voxel interpreter overhead.
x, y, z, c = np.nonzero(g)
s = g[x, y, z, c]

pts = mlab.quiver3d(x, y, z, s, s, s, scalars=c, mode='sphere')
pts.glyph.color_mode = 'color_by_scalar'
pts.glyph.glyph_source.glyph_source.center = [0, 0, 0]

mlab.show()

In [66]:
import model

In [70]:
# Instantiate the project's prediction model wrapper.
m = model.MolPredictModel()

______________________________________________________________________________________________________________
Layer (type)                        Output Shape            Param #      Connected to                         
ligand_in (InputLayer)              (None, 20, 20, 20, 9)   0                                                 
______________________________________________________________________________________________________________
target_in (InputLayer)              (None, 20, 20, 20, 9)   0                                                 
______________________________________________________________________________________________________________
concatenate_9 (Concatenate)         (None, 20, 20, 20, 18)  0            ligand_in[0][0]                      
                                                                         target_in[0][0]                      
______________________________________________________________________________________________________________
c

In [75]:
# Run the model wrapper's own training routine.
m.train()

  0%|          | 0/1 [00:00<?, ?it/s]

Training targets:
- ace
Test targets:
- wee1
Loading training data...


100%|██████████| 1/1 [00:01<00:00,  1.15s/it]
100%|██████████| 1/1 [00:00<00:00,  6.75it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Loading testing data...
Generating testing grid data...


100%|██████████| 10/10 [00:21<00:00,  2.15s/it]


Training...
Epoch 1/1


In [109]:
# Fit the underlying model directly, passing dataset[0] and dataset[1] as the
# first two arguments of fit().
# NOTE(review): `dataset` must already exist in the kernel when this cell runs;
# also confirm that steps_per_epoch is meant to be combined with in-memory data.
m.model.fit(dataset[0], dataset[1], steps_per_epoch=100)

Epoch 1/1


<keras.callbacks.History at 0x14ac32890>

In [23]:
# Load the molecules stored in the ace file; this overwrites the
# notebook-level `mols`.
mols = mol.Mol.readfile('../data/mol/ace.dat')

In [24]:
# Read the first structure from the ace receptor PDB file and convert it to
# the project's Mol representation.
# Builtin next() instead of the Python-2-only .next() method.
target_pdb = next(pybel.readfile('pdb', '../data/dude/ace/receptor.pdb'))
target_mol = mol.Mol.from_pybel(target_pdb)

In [86]:
# A single (target, [molecules]) pair: `target_mol` with only the first entry of `mols`.
data = [(target_mol, [mols[0]])]

# NOTE(review): the arguments (20, 1) are presumably grid size and resolution —
# confirm against trainer.full_grid_generator.
dataset = trainer.full_grid_generator(data, 20, 1)

100%|██████████| 1/1 [00:03<00:00,  3.91s/it]


In [87]:
# Alias for the generated dataset (same object, no copy).
X = dataset

In [110]:
# Predict on X[0] and keep the first element of the result.
o = m.model.predict(X[0])[0]

In [111]:
# Reduce each grid of the first sample to one scalar per voxel by summing
# over axis 3.
# NOTE(review): axis 3 is presumably the atom-type channel — confirm.
a = np.sum(X[0]['ligand_in'][0], axis=3)
b = np.sum(X[0]['target_in'][0], axis=3)
d = np.sum(X[1]['fragment_out'][0], axis=3)
f2 = np.reshape(o, [20,20,20])

# Min-max normalise the prediction. A constant prediction has zero range and
# would turn into NaN through 0/0, so it is mapped to all zeros instead.
f2_min = np.min(f2)
f2_span = np.max(f2) - f2_min
if f2_span == 0:
    f = np.zeros_like(f2)
else:
    f = (f2 - f2_min) / f2_span

# Zero out everything below 30% of the normalised range.
z = np.copy(f)
z[f < 0.3] = 0

print(z.shape)

# Stack ligand, expected fragment and thresholded prediction along a new last axis.
g = np.stack([a,d,z], axis=3)

(20, 20, 20)


In [113]:
# Smallest value of the reshaped raw prediction `f2`.
np.min(f2)

0.31661603

In [168]:
# Sum the first 'target_in' sample over axis 3 and reshape it to
# (20, 20, 20, 1), i.e. with a single-entry last axis.
# NOTE(review): this indexes X['target_in'] directly, while other cells use
# X[0]['target_in'] — X was probably a different object when this cell ran; confirm.
g = np.reshape(np.sum(X['target_in'][0], axis=3), [20,20,20,1])

In [115]:
# Draw every non-zero entry of the 4-D grid `g` as a sphere: position is the
# voxel index, size is the stored value, colour is the last-axis index.
#
# np.nonzero returns the indices in row-major order, which is the same order
# the nested Python loops over (x, y, z) and the last axis would visit them,
# but without per-voxel interpreter overhead.
x, y, z, c = np.nonzero(g)
s = g[x, y, z, c]

pts = mlab.quiver3d(x, y, z, s, s, s, scalars=c, mode='sphere')
pts.glyph.color_mode = 'color_by_scalar'
pts.glyph.glyph_source.glyph_source.center = [0, 0, 0]

mlab.show()