In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import time
import numpy as np
import torch
from torch.nn import Dropout, ELU
import torch.nn.functional as F
from torch import nn
from dgl.nn.pytorch import GATConv
import itertools 
import dgl
from collections import defaultdict as ddict, Counter
from tqdm import tqdm
import pandas as pd
import qgrid


np.set_printoptions(suppress=True, formatter={'float_kind':'{:0.4f}'.format})

Using backend: pytorch


In [6]:
class GATOptimized(nn.Module):
    def __init__(self,
                 in_dim,
                 hidden_dim,
                 out_dim,
                 num_layers=1,
                 heads=8,
                 activation=F.elu,
                 feat_drop=.6,
                 attn_drop=.6,
                 negative_slope=.2,
                 residual=False):
        super(GATOptimized, self).__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, hidden_dim, heads,
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GATConv(
                hidden_dim * heads, hidden_dim, heads,
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection
        self.gat_layers.append(GATConv(
            hidden_dim * heads, out_dim, 1,
            feat_drop, attn_drop, negative_slope, residual, None))

    def forward(self, graph, inputs):
        h = inputs
        g = graph
        for l in range(self.num_layers):
            h = self.gat_layers[l](g, h).flatten(1)
        # output projection
        logits = self.gat_layers[-1](g, h).mean(1)
        return logits

In [61]:
import networkx as nx

g = nx.read_gml('1qbr-good-2.gml', label='id')

nodes = g.nodes(data=True)
node_data = [tuple(v[1].values()) for v in nodes]
edges = g.edges(data=True)
edge_data = [tuple(v[2].values()) for v in edges]

np.savez_compressed('graph.npz', edges=g.edges(), edge_data=edge_data, node_data=node_data)

In [9]:
import os 
a = np.load('02-pdbbind-refined-1k/1a28-0-0.npz')

In [17]:
dir = '02-pdbbind-refined-1k/'
for fn in os.listdir(dir):
    if fn.endswith('npz'):
        break
a = np.load(dir + fn)

In [None]:
dgl.gra

In [20]:
for k in a.keys():
    print(k)

edges
edge_data
node_data


In [51]:
graph_data = np.load('graph.npz')

In [55]:
graph_data['node_data']

array([[131,   0,   0,   0],
       [134,   0,   0,   0],
       [137,  17,   0,   0],
       ...,
       [ 63,   6,   1,   1],
       [ 90,   6,   1,   1],
       [ 91,   6,   1,   1]])

In [39]:
edge_data = []
covalent = [0, 1]
not_covalent = [1, 0]
for (u,v, data) in edges:
    if data['is_covalent'] == 1:
        edge_data.append([data['distance']] + covalent)
    else:
        edge_data.append([data['distance']] + not_covalent)

In [40]:
edge_data

[[1.52293, 0, 1],
 [2.45799, 1, 0],
 [2.59391, 1, 0],
 [3.69126, 1, 0],
 [4.24222, 1, 0],
 [3.91086, 1, 0],
 [5.23243, 1, 0],
 [4.73409, 1, 0],
 [4.7936, 1, 0],
 [5.63425, 1, 0],
 [6.68185, 1, 0],
 [5.88907, 1, 0],
 [4.54636, 1, 0],
 [5.99108, 1, 0],
 [7.60622, 1, 0],
 [7.40035, 1, 0],
 [6.26832, 1, 0],
 [7.79231, 1, 0],
 [6.92206, 1, 0],
 [6.73016, 1, 0],
 [6.21713, 1, 0],
 [7.54696, 1, 0],
 [7.9836, 1, 0],
 [7.94636, 1, 0],
 [7.42263, 1, 0],
 [7.71776, 1, 0],
 [6.96069, 1, 0],
 [6.85712, 1, 0],
 [6.89619, 1, 0],
 [6.68482, 1, 0],
 [5.4409, 1, 0],
 [4.42873, 1, 0],
 [5.73659, 1, 0],
 [7.32484, 1, 0],
 [6.65852, 1, 0],
 [4.39128, 1, 0],
 [4.20958, 1, 0],
 [4.19883, 1, 0],
 [3.76942, 1, 0],
 [3.53628, 1, 0],
 [4.07779, 1, 0],
 [7.89357, 1, 0],
 [1.46685, 0, 1],
 [2.14763, 1, 0],
 [2.43566, 1, 0],
 [2.78878, 1, 0],
 [2.44419, 1, 0],
 [3.78813, 1, 0],
 [3.66008, 1, 0],
 [3.97469, 1, 0],
 [4.46199, 1, 0],
 [5.91089, 1, 0],
 [5.28404, 1, 0],
 [4.17445, 1, 0],
 [5.17826, 1, 0],
 [6.95617, 1,

In [15]:
g.order(), g.size()

(466, 18686)

In [3]:
df = pd.read_csv('02-pdbbind-refined.csv')

In [4]:
df

Unnamed: 0,pdb,run,pose,nfrb,e_docking,rmsd,eLJ,emetal,eHB,eelec,etors,is_good,e_exp,name
0,184l,0,0,2,-5.32687,0.55437,-1.41440,0.00000,0.20600,-0.00003,0.2352,1,-6.444844,184l-0-0.gml
1,184l,0,1,2,-5.15635,1.80368,-1.34504,0.00000,0.26100,-0.00049,0.2352,0,-6.444844,184l-0-1.gml
2,184l,0,2,2,-5.03622,4.19935,-1.37091,0.00000,0.36300,0.00039,0.2352,0,-6.444844,184l-0-2.gml
3,184l,0,3,2,-5.05864,4.24878,-1.40106,0.00000,0.37800,0.00002,0.2352,0,-6.444844,184l-0-3.gml
4,184l,0,4,2,-5.12396,4.24426,-1.41376,0.00000,0.21300,-0.00005,0.2352,0,-6.444844,184l-0-4.gml
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
82789,6equ,9,1,7,-7.02665,8.24941,-2.21915,0.00000,2.69746,0.21076,0.8232,0,-8.711462,6equ-9-1.gml
82790,6equ,9,2,7,-9.65971,1.81683,-3.37676,-1.47861,1.94918,0.14543,0.8232,0,-8.711462,6equ-9-2.gml
82791,6equ,9,3,7,-6.47944,8.34522,-1.83162,0.00000,2.38618,0.03469,0.8232,0,-8.711462,6equ-9-3.gml
82792,6equ,9,4,7,-8.50094,3.69017,-2.97835,-0.72403,1.61687,0.13728,0.8232,0,-8.711462,6equ-9-4.gml
