In [1]:
import pickle
import re
import numpy as np
import sys
import os
import glob
import networkx as nx
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 


import torch
import torch.nn.functional as F

In [2]:
with open('/xdisk/twheeler/jgaiser/deepvs/deepvs/data/protein_config.yaml', 'r') as config_file:
    protein_config = yaml.safe_load(config_file) 

corpus_dir = "/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/corpus/" 

PROTEIN_ATOM_LABELS = protein_config['atom_labels']
PROTEIN_EDGE_LABELS = protein_config['edge_labels']

dummy_index = PROTEIN_ATOM_LABELS.index('DUMMY')
voxel_edge_index = PROTEIN_EDGE_LABELS.index('voxel')

In [3]:
def fetch_decoy_batch(decoy_count):
    random_idx = random.randint(1,75)
    decoy_corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_%s_75.pkl" % random_idx, 'rb'))
    return next(iter(DataLoader(decoy_corpus[1], batch_size=decoy_count, shuffle=True)))

In [4]:
## POCKET GRAPH GCN
pocket_hidden = 512 
NODE_DIMS = 41 
EDGE_DIMS = 9 
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')

edge_weight_modifier = 1
MAX_EDGE_WEIGHT = 15.286330223083496 / edge_weight_modifier

class PocketGCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(NODE_DIMS-3,pocket_hidden)
        
        self.conv1 = GCN2Conv(pocket_hidden, 0.2)
        self.conv2 = GCN2Conv(pocket_hidden, 0.2)
        self.conv3 = GCN2Conv(pocket_hidden, 0.2)

    def forward(self, data):
        x, edge_index, edge_weights = data.x[:,:-3], data.edge_index, data.edge_attr[:,-1] / MAX_EDGE_WEIGHT 

        x = self.linear1(x)
    
        h = self.conv1(x, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv2(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv3(h, x, edge_index, edge_weights)
        h = F.relu(h)
        
        return h

In [5]:
## MOLECULE GRAPH POOLING 
#@title Molecule Pooling

mol_hidden = 512 

class MoleculePool(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = AttentiveFP(in_channels=53, 
                                 hidden_channels=mol_hidden, 
                                 out_channels=mol_hidden,
                                 edge_dim=10,
                                 num_layers=5,
                                 num_timesteps=2,
                                 dropout=0.0)
        
    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.conv1(x, edge_index, edge_weights, batch)
        return h

In [6]:
## POCKET GRAPH POOLING
pool_hidden = 512 

class PoxelPool(torch.nn.Module):
    def __init__(self):
        super().__init__()

#         self.conv1 = GCN2Conv(pool_hidden, 0.1)
#         self.conv2 = GCN2Conv(pool_hidden, 0.1)

        self.pool1 = SAGPooling(pool_hidden)

        self.conv3 = GCN2Conv(pool_hidden, 0.2)
        self.conv4 = GCN2Conv(pool_hidden, 0.2)

        self.pool2 = SAGPooling(pool_hidden)

        gate_nn = MLP([pool_hidden, 1], act='relu')
        nn = MLP([pool_hidden, pool_hidden], act='relu')
        self.global_pool = AttentionalAggregation(gate_nn, nn)

    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr[:,-1], data.batch
        
#         h = self.conv1(x, x, edge_index, edge_weights)
#         h = F.relu(h)

#         h = self.conv2(h, x, edge_index, edge_weights)
#         h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool1(x, edge_index, edge_weights, batch)

        h = self.conv3(h, h, edge_index, edge_weights)
        h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool2(h, edge_index, edge_weights, batch)

        h = self.conv4(h, h, edge_index, edge_weights)
        h = F.relu(h)
        
        h = self.global_pool(h, index=batch)
        return h    

In [8]:
# ACTIVE CLASSIFIER
#@title Classifier

class ActiveClassifier(torch.nn.Module):
    def __init__(self, pocket_model, poxel_model, molecule_model):
        super(ActiveClassifier, self).__init__()
        self.pocket_model = PocketGCN()
        self.pox_pooler = PoxelPool()
        self.mol_pooler = MoleculePool()

#         self.linear1 = nn.Linear(2048, 1024)
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 1)

#         self.linear1 = nn.Linear(512, 256)
#         self.linear2 = nn.Linear(256, 64)
#         self.linear3 = nn.Linear(64, 1)
#         self.linear4 = nn.Linear(256, 1)

        self.relu = nn.ReLU()

    def forward(self, pocket_batch, active_batch, decoy_batch):
#         voxel_node_indices = torch.where(pocket_batch.x[:, dummy_index] == 1.0)[0]
#         non_voxel_node_indices = torch.where(pocket_batch.x[:, dummy_index] == 0)[0]
#         voxel_edge_indices = torch.where(pocket_batch.edge_attr[:, voxel_edge_index]==1.0)[0]
        
        pocket_embeds = self.pocket_model(pocket_batch)
        
#         pocket_embeds[non_voxel_node_indices] = torch.zeros(pocket_embeds.size(1))
#         trimmed_edge_index = torch.vstack((pocket_batch.edge_index[0][voxel_edge_indices],
#                                            pocket_batch.edge_index[1][voxel_edge_indices]))
        
#         trimmed_edge_attr = pocket_batch.edge_attr[voxel_edge_indices]
        
        pocket_batch.x = pocket_embeds
#         pocket_batch.edge_index = trimmed_edge_index
#         pocket_batch.edge_attr = trimmed_edge_attr
        
        poxel_embeds = self.pox_pooler(pocket_batch)
        active_embeds = self.mol_pooler(active_batch)
        decoy_embeds = self.mol_pooler(decoy_batch)
        
        poxel_actives = torch.hstack((poxel_embeds, active_embeds))
        poxel_decoys = torch.hstack((torch.cat([poxel_embeds]*len(decoy_embeds), dim=0), 
                                     decoy_embeds.repeat_interleave(poxel_embeds.size(0), dim=0)))
        
        all_embeds = torch.vstack((poxel_actives, poxel_decoys))

        x = self.linear1(all_embeds) 
        x = self.relu(x)
        
        x = self.linear2(x) 
        x = self.relu(x)
        
        x = self.linear3(x) 
#         x = self.relu(x)
        
#         o = self.linear4(x) 
        return x
    
    
# batch_size=32
# ac = ActiveClassifier(PocketGCN, PoxelPool, MoleculePool)

# corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_3_75.pkl", 'rb'))
# mol_loader = DataLoader(corpus[1], batch_size=batch_size, shuffle=False)

# pocket_loader = DataLoader(corpus[2], batch_size=batch_size, shuffle=False)
# decoy_batch = fetch_decoy_batch(200)

# for active_batch, pocket_batch in zip(mol_loader, pocket_loader):
#     out = ac(pocket_batch, active_batch, decoy_batch)
#     print(out)
#     break

In [None]:
# TRAINING LOOP
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([100])).to(device)
# criterion = nn.BCEWithLogitsLoss().to(device)
ac = ActiveClassifier(PocketGCN, PoxelPool, MoleculePool).to(device)
optimizer = torch.optim.Adam(ac.parameters(), lr=1e-3)

def training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        for corpus_idx in range(1,75):
            loss_history = []
            corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_%s_75.pkl" % corpus_idx, 'rb'))
            
            mol_loader = DataLoader(corpus[1], batch_size=batch_size, shuffle=False)
            pocket_loader = DataLoader(corpus[2], batch_size=batch_size, shuffle=False)
            
            for active_batch, pocket_batch in zip(mol_loader, pocket_loader):
                optimizer.zero_grad() 
                active_batch = active_batch.to(device)
                pocket_batch = pocket_batch.to(device)
                decoy_batch = fetch_decoy_batch(batch_size).to(device)
                
                out = ac(pocket_batch, active_batch, decoy_batch)
                
                y = torch.zeros(out.size(0))
                y[:batch_size] = 1
                y = torch.unsqueeze(y, dim=1).to(device)
                
                loss = criterion(out, y)
                loss_history.append(loss.item())
                loss.backward()
                optimizer.step()
                
            print(sum(loss_history) / len(loss_history))
            print('active:', [float("%.3f" % x) for x in torch.sigmoid(out[:7].squeeze()).tolist()])
            print('decoy:', [float("%.3f" % x) for x in torch.sigmoid(out[-7:].squeeze()).tolist()])
            print('')
    
#       for batch_idx in range(int(POXEL_COLLECTION_SIZE / batch_size)):
#         y = torch.hstack([torch.ones(batch_size), torch.zeros(batch_size*np_ratio)]).unsqueeze(dim=1).to(device) 
      
#         pox_batch, active_batch, decoy_batch, active_indices = retrieve_training_batch(batch_size, np_ratio, poxel_collection, col_idx)
    
#         # pox_batch, active_batch, decoy_batch, active_indices = retrieve_DUMMY_batch(batch_size, np_ratio, poxel_collection, COL_IDX, [Ameans, Astds], [Bmeans, Bstds])
#         out = ac(pox_batch.to(device), active_batch.to(device), decoy_batch.to(device), np_ratio)
#         loss = criterion(out, y)
#         print(batch_idx, torch.mean(sigmoid(out[0:10])).item(), torch.mean(sigmoid(out[10:])).item(), loss.item())

        # print(loss)

training_loop(100, 32)


EPOCH 0
2.994863669077555
active: [0.809, 0.8, 0.8, 0.816, 0.848, 0.849, 0.849]
decoy: [0.79, 0.79, 0.789, 0.837, 0.838, 0.838, 0.838]

3.8560456964704724
active: [0.65, 0.625, 0.632, 0.636, 0.649, 0.649, 0.649]
decoy: [0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65]

2.785764561759101
active: [0.803, 0.794, 0.802, 0.774, 0.785, 0.785, 0.79]
decoy: [0.804, 0.809, 0.815, 0.803, 0.804, 0.809, 0.815]

3.0209287537468805
active: [0.794, 0.816, 0.816, 0.808, 0.783, 0.783, 0.783]
decoy: [0.791, 0.791, 0.791, 0.801, 0.801, 0.801, 0.801]

2.783065266079373
active: [0.822, 0.817, 0.812, 0.825, 0.817, 0.817, 0.817]
decoy: [0.818, 0.818, 0.818, 0.818, 0.818, 0.818, 0.818]

2.807522243923611
active: [0.818, 0.818, 0.819, 0.819, 0.821, 0.821, 0.821]
decoy: [0.822, 0.822, 0.822, 0.819, 0.819, 0.819, 0.819]

2.780528253979153
active: [0.816, 0.822, 0.822, 0.822, 0.817, 0.817, 0.817]
decoy: [0.816, 0.816, 0.816, 0.818, 0.818, 0.818, 0.818]

2.762437343597412
active: [0.825, 0.822, 0.822, 0.829, 0.831, 0.831

2.8457711272769504
active: [0.858, 0.859, 0.858, 0.859, 0.859, 0.859, 0.858]
decoy: [0.859, 0.859, 0.859, 0.859, 0.859, 0.859, 0.859]

2.896095355351766
active: [0.869, 0.869, 0.869, 0.869, 0.869, 0.869, 0.869]
decoy: [0.869, 0.869, 0.869, 0.869, 0.869, 0.869, 0.869]

2.870506180657281
active: [0.867, 0.867, 0.867, 0.867, 0.867, 0.867, 0.867]
decoy: [0.867, 0.867, 0.867, 0.867, 0.867, 0.867, 0.867]

2.872984674241808
active: [0.853, 0.853, 0.852, 0.853, 0.853, 0.853, 0.853]
decoy: [0.852, 0.853, 0.853, 0.853, 0.853, 0.853, 0.853]

2.884339544508192
active: [0.868, 0.868, 0.868, 0.868, 0.868, 0.868, 0.868]
decoy: [0.868, 0.868, 0.868, 0.868, 0.868, 0.868, 0.868]

2.8651338948143854
active: [0.866, 0.866, 0.866, 0.866, 0.866, 0.866, 0.866]
decoy: [0.866, 0.866, 0.866, 0.866, 0.866, 0.866, 0.866]

3.0862305694156222
active: [0.854, 0.854, 0.854, 0.854, 0.854, 0.854, 0.854]
decoy: [0.854, 0.854, 0.854, 0.854, 0.854, 0.854, 0.854]

2.8718938297695584
active: [0.895, 0.895, 0.895, 0.895, 0.8

2.8374483585357666
active: [0.86, 0.86, 0.86, 0.86, 0.86, 0.86, 0.86]
decoy: [0.86, 0.86, 0.86, 0.86, 0.86, 0.86, 0.86]

2.890724923875597
active: [0.873, 0.873, 0.873, 0.874, 0.874, 0.874, 0.874]
decoy: [0.873, 0.873, 0.873, 0.873, 0.873, 0.873, 0.873]

2.857821994357639
active: [0.865, 0.865, 0.865, 0.865, 0.865, 0.865, 0.866]
decoy: [0.865, 0.865, 0.865, 0.865, 0.865, 0.865, 0.865]

2.8962657716539173
active: [0.868, 0.868, 0.868, 0.868, 0.868, 0.868, 0.868]
decoy: [0.869, 0.868, 0.868, 0.868, 0.868, 0.868, 0.868]

2.87149969736735
active: [0.859, 0.859, 0.854, 0.859, 0.859, 0.854, 0.859]
decoy: [0.854, 0.859, 0.859, 0.854, 0.859, 0.859, 0.854]

2.8444112406836615
active: [0.86, 0.86, 0.86, 0.86, 0.86, 0.86, 0.86]
decoy: [0.86, 0.86, 0.86, 0.86, 0.86, 0.86, 0.86]

2.950733184814453
active: [0.831, 0.832, 0.832, 0.831, 0.832, 0.832, 0.831]
decoy: [0.832, 0.831, 0.832, 0.832, 0.831, 0.832, 0.832]

4.75485701031155
active: [0.606, 0.61, 0.608, 0.606, 0.61, 0.608, 0.606]
decoy: [0.608, 

3.1140112347073026
active: [0.932, 0.918, 0.888, 0.965, 0.965, 0.965, 0.904]
decoy: [0.91, 0.89, 0.89, 0.89, 0.884, 0.884, 0.884]

3.540081739425659
active: [0.578, 0.714, 0.723, 0.779, 0.779, 0.779, 0.708]
decoy: [0.802, 0.707, 0.707, 0.707, 0.712, 0.712, 0.712]

2.7870369222429066
active: [0.913, 0.961, 0.963, 0.879, 0.879, 0.879, 0.897]
decoy: [0.922, 0.893, 0.893, 0.893, 0.927, 0.927, 0.927]

3.105933321846856
active: [0.901, 0.864, 0.83, 0.836, 0.836, 0.836, 0.865]
decoy: [0.831, 0.833, 0.833, 0.833, 0.831, 0.831, 0.831]

2.906083001030816
active: [0.839, 0.827, 0.858, 0.817, 0.817, 0.817, 0.816]
decoy: [0.816, 0.847, 0.847, 0.847, 0.819, 0.819, 0.819]

2.8215960926479764
active: [0.858, 0.858, 0.88, 0.858, 0.858, 0.858, 0.855]
decoy: [0.88, 0.862, 0.862, 0.862, 0.874, 0.874, 0.874]

2.89404559135437
active: [0.857, 0.855, 0.854, 0.855, 0.855, 0.855, 0.857]
decoy: [0.854, 0.855, 0.855, 0.855, 0.859, 0.859, 0.859]

2.901286416583591
active: [0.872, 0.872, 0.872, 0.87, 0.87, 0.87, 0

2.856575118170844
active: [0.841, 0.843, 0.84, 0.842, 0.842, 0.842, 0.842]
decoy: [0.861, 0.861, 0.861, 0.861, 0.843, 0.843, 0.843]

2.8075962596469455
active: [0.866, 0.862, 0.863, 0.862, 0.862, 0.862, 0.865]
decoy: [0.867, 0.865, 0.865, 0.865, 0.863, 0.863, 0.863]

2.844169749153985
active: [0.869, 0.86, 0.862, 0.865, 0.865, 0.865, 0.864]
decoy: [0.865, 0.937, 0.937, 0.937, 0.929, 0.929, 0.929]

3.080356650882297
active: [0.855, 0.854, 0.861, 0.861, 0.861, 0.861, 0.852]
decoy: [0.862, 0.855, 0.855, 0.855, 0.868, 0.868, 0.868]

3.566678603490194
active: [0.888, 0.888, 0.895, 0.895, 0.895, 0.895, 0.89]
decoy: [0.906, 0.889, 0.889, 0.889, 0.886, 0.886, 0.886]

3.391104989581638
active: [0.949, 0.957, 0.958, 0.943, 0.943, 0.943, 0.943]
decoy: [0.95, 0.955, 0.955, 0.955, 0.943, 0.943, 0.943]

3.3614697721269398
active: [0.919, 0.922, 0.918, 0.944, 0.944, 0.944, 0.92]
decoy: [0.918, 0.916, 0.916, 0.916, 0.915, 0.915, 0.915]

3.1179155243767633
active: [0.806, 0.816, 0.752, 0.821, 0.821, 0.

2.712531010309855
active: [0.862, 0.855, 0.856, 0.861, 0.856, 0.856, 0.856]
decoy: [0.854, 0.854, 0.854, 0.877, 0.877, 0.877, 0.877]

3.635613441467285
active: [0.863, 0.861, 0.889, 0.889, 0.889, 0.889, 0.858]
decoy: [0.857, 0.857, 0.857, 0.857, 0.857, 0.857, 0.857]

2.5441407203674316
active: [0.882, 0.883, 0.881, 0.881, 0.882, 0.866, 0.867]
decoy: [0.866, 0.866, 0.866, 0.866, 0.866, 0.866, 0.866]

2.773923291100396
active: [0.823, 0.82, 0.824, 0.822, 0.822, 0.822, 0.822]
decoy: [0.84, 0.84, 0.84, 0.821, 0.821, 0.821, 0.821]

2.8648468653361
active: [0.83, 0.823, 0.822, 0.827, 0.827, 0.827, 0.823]
decoy: [0.823, 0.887, 0.887, 0.887, 0.831, 0.831, 0.831]

2.782422754499647
active: [0.853, 0.861, 0.861, 0.909, 0.909, 0.909, 0.852]
decoy: [0.855, 0.855, 0.855, 0.855, 0.853, 0.853, 0.853]

2.823685222201877
active: [0.864, 0.867, 0.864, 0.865, 0.865, 0.865, 0.864]
decoy: [0.871, 0.864, 0.864, 0.864, 0.867, 0.867, 0.867]

2.9102793799506292
active: [0.87, 0.869, 0.878, 0.87, 0.87, 0.87, 0.

2.7808980147043862
active: [0.818, 0.869, 0.834, 0.852, 0.813, 0.813, 0.813]
decoy: [0.834, 0.834, 0.834, 0.832, 0.832, 0.832, 0.832]

2.6870834032694497
active: [0.851, 0.838, 0.838, 0.848, 0.853, 0.853, 0.853]
decoy: [0.854, 0.854, 0.854, 0.853, 0.853, 0.853, 0.853]

2.715130647023519
active: [0.8, 0.854, 0.857, 0.763, 0.856, 0.856, 0.856]
decoy: [0.854, 0.854, 0.854, 0.855, 0.854, 0.854, 0.855]

2.6986666785346136
active: [0.86, 0.863, 0.85, 0.85, 0.85, 0.85, 0.85]
decoy: [0.85, 0.85, 0.85, 0.85, 0.85, 0.85, 0.85]

2.7053103711869984
active: [0.838, 0.845, 0.845, 0.845, 0.838, 0.838, 0.838]
decoy: [0.845, 0.845, 0.845, 0.841, 0.841, 0.841, 0.841]

2.7122231324513755
active: [0.839, 0.834, 0.834, 0.844, 0.833, 0.833, 0.833]
decoy: [0.835, 0.835, 0.835, 0.837, 0.837, 0.837, 0.837]

2.691950480143229
active: [0.846, 0.845, 0.836, 0.842, 0.835, 0.835, 0.835]
decoy: [0.84, 0.84, 0.84, 0.839, 0.839, 0.839, 0.839]

2.73573891321818
active: [0.851, 0.836, 0.842, 0.8, 0.832, 0.832, 0.832]
de

2.846588955985175
active: [0.858, 0.856, 0.862, 0.857, 0.857, 0.857, 0.86]
decoy: [0.864, 0.859, 0.859, 0.859, 0.869, 0.869, 0.87]

2.8304280704922147
active: [0.854, 0.854, 0.855, 0.855, 0.855, 0.855, 0.855]
decoy: [0.855, 0.855, 0.855, 0.855, 0.854, 0.854, 0.854]

2.8190395567152233
active: [0.856, 0.857, 0.887, 0.857, 0.856, 0.857, 0.857]
decoy: [0.857, 0.857, 0.857, 0.857, 0.857, 0.857, 0.857]

2.811163478427463
active: [0.864, 0.864, 0.865, 0.863, 0.863, 0.863, 0.865]
decoy: [0.865, 0.863, 0.863, 0.863, 0.864, 0.864, 0.864]

2.826466586854723
active: [0.864, 0.865, 0.864, 0.865, 0.865, 0.865, 0.865]
decoy: [0.864, 0.865, 0.865, 0.865, 0.864, 0.864, 0.865]

2.8093716303507485
active: [0.861, 0.864, 0.864, 0.85, 0.85, 0.85, 0.859]
decoy: [0.859, 0.86, 0.86, 0.86, 0.859, 0.859, 0.859]

2.829172372817993
active: [0.852, 0.858, 0.858, 0.863, 0.863, 0.863, 0.861]
decoy: [0.858, 0.862, 0.862, 0.862, 0.867, 0.867, 0.867]

2.820903195275201
active: [0.869, 0.868, 0.862, 0.853, 0.853, 0.853

2.8274370299445257
active: [0.867, 0.866, 0.863, 0.864, 0.864, 0.865, 0.852]
decoy: [0.861, 0.859, 0.859, 0.859, 0.869, 0.869, 0.87]

2.83912369940016
active: [0.862, 0.861, 0.868, 0.861, 0.861, 0.861, 0.869]
decoy: [0.863, 0.872, 0.872, 0.872, 0.869, 0.868, 0.868]

2.835159699122111
active: [0.861, 0.86, 0.855, 0.851, 0.851, 0.852, 0.867]
decoy: [0.867, 0.867, 0.867, 0.867, 0.866, 0.866, 0.866]

2.905297729704115
active: [0.81, 0.818, 0.838, 0.841, 0.841, 0.841, 0.853]
decoy: [0.849, 0.841, 0.841, 0.841, 0.848, 0.848, 0.848]

2.8254474798838296
active: [0.847, 0.845, 0.846, 0.863, 0.863, 0.863, 0.856]
decoy: [0.831, 0.819, 0.819, 0.819, 0.821, 0.821, 0.821]

2.8115982479519315
active: [0.875, 0.875, 0.874, 0.882, 0.882, 0.882, 0.874]
decoy: [0.875, 0.875, 0.875, 0.875, 0.874, 0.874, 0.874]

2.863069030973646
active: [0.868, 0.868, 0.867, 0.867, 0.867, 0.867, 0.867]
decoy: [0.868, 0.87, 0.87, 0.87, 0.867, 0.867, 0.867]

2.842922819985284
active: [0.856, 0.856, 0.856, 0.86, 0.86, 0.86, 

2.8702912595536976
active: [0.868, 0.862, 0.862, 0.871, 0.871, 0.871, 0.874]
decoy: [0.869, 0.87, 0.87, 0.87, 0.872, 0.872, 0.872]

2.8567949401007757
active: [0.831, 0.83, 0.849, 0.852, 0.852, 0.852, 0.843]
decoy: [0.851, 0.85, 0.85, 0.85, 0.845, 0.845, 0.845]

2.8001910050710044
active: [0.84, 0.844, 0.833, 0.86, 0.86, 0.86, 0.861]
decoy: [0.832, 0.87, 0.87, 0.87, 0.857, 0.857, 0.858]

3.7351188394758434
active: [0.862, 0.862, 0.841, 0.859, 0.859, 0.86, 0.861]
decoy: [0.866, 0.866, 0.866, 0.866, 0.853, 0.853, 0.853]

3.0037394629584417
active: [0.946, 0.937, 0.939, 0.938, 0.938, 0.938, 0.937]
decoy: [0.942, 0.941, 0.941, 0.941, 0.94, 0.939, 0.939]

3.7699496746063232
active: [0.695, 0.695, 0.695, 0.684, 0.684, 0.684, 0.679]
decoy: [0.682, 0.68, 0.68, 0.68, 0.686, 0.686, 0.686]

3.561355617311266
active: [0.824, 0.831, 0.831, 0.766, 0.766, 0.766, 0.69]
decoy: [0.908, 0.845, 0.846, 0.846, 0.773, 0.774, 0.774]

3.008241891860962
active: [0.898, 0.9, 0.901, 0.903, 0.903, 0.903, 0.899]
de

2.8247345553504095
active: [0.854, 0.86, 0.859, 0.859, 0.859, 0.859, 0.86]
decoy: [0.857, 0.858, 0.858, 0.858, 0.859, 0.858, 0.858]

2.808758099873861
active: [0.86, 0.862, 0.851, 0.859, 0.858, 0.859, 0.86]
decoy: [0.865, 0.85, 0.85, 0.85, 0.861, 0.861, 0.861]

2.847352902094523
active: [0.868, 0.866, 0.857, 0.858, 0.858, 0.858, 0.841]
decoy: [0.856, 0.806, 0.806, 0.806, 0.856, 0.856, 0.856]

2.8175590568118625
active: [0.87, 0.864, 0.863, 0.876, 0.876, 0.876, 0.878]
decoy: [0.877, 0.862, 0.862, 0.862, 0.868, 0.868, 0.869]

2.8102514213985867
active: [0.871, 0.882, 0.872, 0.882, 0.882, 0.882, 0.882]
decoy: [0.873, 0.845, 0.845, 0.845, 0.881, 0.881, 0.881]

2.897636651992798
active: [0.872, 0.874, 0.856, 0.861, 0.861, 0.861, 0.862]
decoy: [0.866, 0.868, 0.868, 0.868, 0.872, 0.872, 0.872]

2.840685897403293
active: [0.859, 0.873, 0.852, 0.871, 0.871, 0.871, 0.856]
decoy: [0.878, 0.853, 0.853, 0.853, 0.864, 0.864, 0.864]

2.8275651137034097
active: [0.87, 0.868, 0.866, 0.87, 0.87, 0.87, 0

2.5651001930236816
active: [0.873, 0.873, 0.872, 0.872, 0.873, 0.862, 0.863]
decoy: [0.86, 0.86, 0.86, 0.86, 0.86, 0.86, 0.86]

2.7595749696095786
active: [0.835, 0.836, 0.833, 0.833, 0.836, 0.836, 0.836]
decoy: [0.84, 0.84, 0.84, 0.836, 0.835, 0.835, 0.835]

2.868927107916938
active: [0.843, 0.831, 0.829, 0.837, 0.836, 0.836, 0.828]
decoy: [0.83, 0.838, 0.837, 0.837, 0.829, 0.829, 0.829]



In [None]:
for idx, target_dir in enumerate(root_dir):
    
    if idx < random.randint(0,100):
        continue
        
    print(target_dir)
    target_id = target_dir.split('/')[-2]
    data = pickle.load(open("%s%s_ligand_graph.pkl" % (target_dir, target_id), 'rb'))
    print(data.edge_attr.shape)
    break
    
#     g = torch_geometric.utils.to_networkx(data, to_undirected=True)
#     nx.draw(g)
#     break