In [1]:
import pickle
import re
import numpy as np
import sys
import os
import glob
import networkx as nx
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 

import torch
import torch.nn.functional as F

In [2]:
with open('/xdisk/twheeler/jgaiser/deepvs/deepvs/data/protein_config.yaml', 'r') as config_file:
    protein_config = yaml.safe_load(config_file) 

PROTEIN_ATOM_LABELS = protein_config['atom_labels']
PROTEIN_EDGE_LABELS = protein_config['edge_labels']
INTERACTION_LABELS = protein_config['interaction_labels']
PDB_IDS = []

dummy_index = PROTEIN_ATOM_LABELS.index('DUMMY')
voxel_edge_index = PROTEIN_EDGE_LABELS.index('voxel')

pocket_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/pockets/'
mol_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/molecules/'

pocket_file_structure = pocket_graph_dir + "%s_pocket_graph.pkl"
mol_file_structure = mol_graph_dir + "%s_mol_graph.pkl"

In [3]:
PDB_IDS = []

for item in glob.glob(pocket_graph_dir + "*"):
    PDB_IDS.append(item.split('/')[-1].split('_')[0])

PDB_IDS = np.array(PDB_IDS)

In [4]:
def fetch_training_batch(pdb_ids, batch_size):
    sample_ids = np.random.choice(PDB_IDS, batch_size)
    pocket_graphs = [] 
    mol_graphs = []
    decoy_graphs = []
    
    for s_id in sample_ids:
        pocket_graphs.append(pickle.load(open(pocket_file_structure % s_id, 'rb')))
        mol_graphs.append(pickle.load(open(mol_file_structure % s_id, 'rb')))
        decoy_graphs.append(pickle.load(open(mol_file_structure % np.random.choice(PDB_IDS), 'rb')))
    
    mol_graphs += decoy_graphs
    
    pocket_loader = DataLoader(pocket_graphs, batch_size=batch_size, shuffle=False)
    mol_loader = DataLoader(mol_graphs, batch_size=batch_size*2, shuffle=False)
    
    return next(iter(pocket_loader)), next(iter(mol_loader))


In [5]:
interaction_count = [0 for x in INTERACTION_LABELS]

for s_id in PDB_IDS:
    g = (pickle.load(open(mol_file_structure % s_id, 'rb')))
    interacting_voxel_indices = torch.where(torch.sum(g.y, dim=1) > 0)[0]
    for idx in torch.where(g.y[interacting_voxel_indices] > 0)[1]:
        interaction_count[idx.item()] += 1
        
print(INTERACTION_LABELS)
print(interaction_count)
    

['halogenbond', 'hbond_a', 'hbond_d', 'hydroph_interaction', 'pication_c', 'pication_r', 'pistack', 'saltbridge_n', 'saltbridge_p']
[1299, 40523, 30777, 63876, 547, 7661, 36457, 6324, 2158]


In [6]:
interaction_weights = [1/(x/max(interaction_count)) for x in interaction_count]
print(interaction_weights)

[49.17321016166282, 1.5762900081435234, 2.075445949897651, 1.0, 116.77513711151735, 8.337814906670147, 1.7520915050607564, 10.10056925996205, 29.599629286376274]


In [7]:
# class Event:
#     def __init__(self, sr1=None, foobar=None):
#         self.sr1 = sr1
#         self.foobar = foobar
#         self.state = STATE_NON_EVENT
 
# # Event class wrappers to provide syntatic sugar
# class TypeTwoEvent(Event):
#     def __init__(self, level=None):
#         self.sr1 = level
#         self.state = STATE_EVENT_TWO
        
# class TypeTwoEvent(Event):
#     def __init__(self, level=None, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.sr1 = level
#         self.state = STATE_EVENT_TWO
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool

class AtomicAttentiveFP(AttentiveFP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.voxel_classifier = nn.Linear(kwargs['hidden_channels'], 9)
        
    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.voxel_classifier(x), self.lin2(out)

In [8]:
import inspect


print("".join(inspect.getsourcelines(AttentiveFP.forward)[0]))

    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p

In [9]:
## MOLECULE GRAPH POOLING 
#@title Molecule Pooling

mol_hidden = 512 

class MoleculePool(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = AtomicAttentiveFP(in_channels=52, 
                                 hidden_channels=mol_hidden, 
                                 out_channels=mol_hidden,
                                 edge_dim=10,
                                 num_layers=5,
                                 num_timesteps=5,
                                 dropout=0.0)
    
        
    
    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.conv1(x, edge_index, edge_weights, batch)
        return h
    
model = MoleculePool()
pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, 32)

criterion = nn.BCEWithLogitsLoss()
atom_embeds, mol_embed = model(mol_batch)
interacting_voxel_indices = torch.where(torch.sum(mol_batch.y, dim=1) > 0)[0]
criterion(atom_embeds[interacting_voxel_indices], mol_batch.y[interacting_voxel_indices])

tensor(0.6906, dtype=torch.float64,
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [10]:
BATCH_SIZE = 32
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(interaction_weights)).to(device)
model = MoleculePool().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def mol_training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        for batch_idx in range(int(len(PDB_IDS) / BATCH_SIZE)):
            loss_history = []
            pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, 32)
            mol_batch = mol_batch.to(device)
            
            optimizer.zero_grad() 

            atom_embeds, mol_embed = model(mol_batch)

            interacting_voxel_indices = torch.where(torch.sum(mol_batch.y, dim=1) > 0)[0]
            loss = criterion(atom_embeds[interacting_voxel_indices], 
                             mol_batch.y[interacting_voxel_indices].to(device))
                               
            
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()
                    
            if batch_idx % 10 == 0:
                print(sum(loss_history) / len(loss_history))
                loss_history = []

# mol_training_loop(100, 32)

In [15]:
## POCKET GRAPH GCN
pocket_hidden = 512 
NODE_DIMS = 41 
EDGE_DIMS = 9 
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')

edge_weight_modifier = 1
MAX_EDGE_WEIGHT = 15.286330223083496 / edge_weight_modifier

class PocketGCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(NODE_DIMS-3,pocket_hidden)
        
        self.conv1 = GCN2Conv(pocket_hidden, 0.2)
        self.conv2 = GCN2Conv(pocket_hidden, 0.2)
        self.conv3 = GCN2Conv(pocket_hidden, 0.2)
        
        self.voxel_prediction = nn.Linear(pocket_hidden,9)

    def forward(self, data):
        x, edge_index, edge_weights = data.x[:,:-3], data.edge_index, data.edge_attr[:,-1] / MAX_EDGE_WEIGHT 

        x = self.linear1(x)
    
        h = self.conv1(x, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv2(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv3(h, x, edge_index, edge_weights)
        h = F.relu(h)
        
        return self.voxel_prediction(h), h
        
        
model = PocketGCN()
pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, 16)
criterion = nn.BCEWithLogitsLoss()
voxel_embeds, _ = model(pocket_batch)
interacting_voxel_indices = torch.where(torch.sum(pocket_batch.y, dim=1) > 0)[0]
criterion(voxel_embeds[interacting_voxel_indices], pocket_batch.y[interacting_voxel_indices])
        

tensor(0.6919, dtype=torch.float64,
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [18]:
BATCH_SIZE = 16
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(interaction_weights)).to(device)
model = PocketGCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def mol_training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        for batch_idx in range(int(len(PDB_IDS) / batch_size)):
            loss_history = []
            pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, 32)
            pocket_batch = pocket_batch.to(device)
            
            optimizer.zero_grad() 

            voxel_embeds,_ = model(pocket_batch)

            interacting_voxel_indices = torch.where(torch.sum(pocket_batch.y, dim=1) > 0)[0]
            loss = criterion(voxel_embeds[interacting_voxel_indices], 
                             pocket_batch.y[interacting_voxel_indices].to(device))
                               
            
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()
                    
            if batch_idx % 10 == 0:
                print(sum(loss_history) / len(loss_history))
                loss_history = []

mol_training_loop(100, BATCH_SIZE)

EPOCH 0
0.8772057591976178
0.8814792667468496
0.7724420435185684
0.6317582452860929
0.7521522869624137
0.6236169924374215
0.6265521881615626
0.7595097036211796
0.7058004424531212
0.795169066033177
0.8426764223555572
0.8912911965793044
0.7254604871552458
0.5961641725174945
0.678087353566586
0.788790090883841
0.7639377332313084
0.7552595334430137
0.6195694806414231
0.5415449231132282
0.5644196772990164
0.8102796001289015
0.640316067583322
0.6115450794109032
0.5648700620592927
0.6466832170772769
0.533578156996897
0.5144788142505027
0.5524863108999009
0.5924881660431535
0.4917460778547627
0.6086022815085812
0.5535892681379121
0.5891871367752499
0.7498581800686909
0.5035398636149109
0.8458210700549625
0.574805321329095
0.4789835999375726
0.6299925963519087
0.5713546522633901
0.4814444000888872
0.531038688582531
0.5440264997758042
0.6602316471958343
0.578405628090927
0.5320423853722986
0.4563847244164438
0.4450537128741511
0.4277876474183357
0.6816472518592002
0.5447116970352348
0.7025937149

0.6707385126025881
0.3757787735083245
0.4218771783152082
0.4626161040309235
0.42271105843723905
0.45566118770347713
0.46555382887056684
0.4056316197766852
0.36266037034760257
0.3786612857483997
0.4395643615847561
0.3852659434125195
0.4391682265020668
0.42190751510786356
0.3982172193921508
0.4814595012451262
0.43877914238025634
0.33852612639957524
0.5549551048063579
0.40834342061553414
0.427680108906651
0.41569122175665213
0.4735632637067555
0.42376589563415584
0.4756962479475812
0.3951722191363597
0.43462292624832405
0.3985626261896211
0.43582411891571626
0.4098631232092347
0.3703038752973044
0.6346856518331364
0.48690468371163814
0.48513766187537277
0.3752594989726427
0.41343833770658256
0.4737911909382178
0.5774099741234883
0.40823907332113935
0.44669661287582546
0.37745653406699337
0.4252433818927323
0.44612697248639477
0.680957339219957
0.47204126601068624
0.3356392367901982
0.4487649942718764
0.3531289759144088
0.4332190813907578
0.3643096709336652
0.35986679789688286
0.3792770719

KeyboardInterrupt: 

In [None]:
## POCKET GRAPH POOLING
pool_hidden = 512 

class PoxelPool(torch.nn.Module):
    def __init__(self):
        super().__init__()

#         self.conv1 = GCN2Conv(pool_hidden, 0.1)
#         self.conv2 = GCN2Conv(pool_hidden, 0.1)

        self.pool1 = SAGPooling(pool_hidden)

        self.conv3 = GCN2Conv(pool_hidden, 0.2)
        self.conv4 = GCN2Conv(pool_hidden, 0.2)

        self.pool2 = SAGPooling(pool_hidden)

        gate_nn = MLP([pool_hidden, 1], act='relu')
        nn = MLP([pool_hidden, pool_hidden], act='relu')
        self.global_pool = AttentionalAggregation(gate_nn, nn)

    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr[:,-1], data.batch
        
#         h = self.conv1(x, x, edge_index, edge_weights)
#         h = F.relu(h)

#         h = self.conv2(h, x, edge_index, edge_weights)
#         h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool1(x, edge_index, edge_weights, batch)

        h = self.conv3(h, h, edge_index, edge_weights)
        h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool2(h, edge_index, edge_weights, batch)

        h = self.conv4(h, h, edge_index, edge_weights)
        h = F.relu(h)
        
        h = self.global_pool(h, index=batch)
        return h    
    
mol_batch.y[interacting_voxel_indices]

In [None]:
# ACTIVE CLASSIFIER
#@title Classifier

class ActiveClassifier(torch.nn.Module):
    def __init__(self, pocket_model, poxel_model, molecule_model):
        super(ActiveClassifier, self).__init__()
        self.pocket_model = PocketGCN()
        self.pox_pooler = PoxelPool()
        self.mol_pooler = MoleculePool()

#         self.linear1 = nn.Linear(2048, 1024)
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 1)

#         self.linear1 = nn.Linear(512, 256)
#         self.linear2 = nn.Linear(256, 64)
#         self.linear3 = nn.Linear(64, 1)
#         self.linear4 = nn.Linear(256, 1)

        self.relu = nn.ReLU()

    def forward(self, pocket_batch, active_batch, decoy_batch):
#         voxel_node_indices = torch.where(pocket_batch.x[:, dummy_index] == 1.0)[0]
#         non_voxel_node_indices = torch.where(pocket_batch.x[:, dummy_index] == 0)[0]
#         voxel_edge_indices = torch.where(pocket_batch.edge_attr[:, voxel_edge_index]==1.0)[0]
        
        pocket_embeds = self.pocket_model(pocket_batch)
        
#         pocket_embeds[non_voxel_node_indices] = torch.zeros(pocket_embeds.size(1))
#         trimmed_edge_index = torch.vstack((pocket_batch.edge_index[0][voxel_edge_indices],
#                                            pocket_batch.edge_index[1][voxel_edge_indices]))
        
#         trimmed_edge_attr = pocket_batch.edge_attr[voxel_edge_indices]
        
        pocket_batch.x = pocket_embeds
#         pocket_batch.edge_index = trimmed_edge_index
#         pocket_batch.edge_attr = trimmed_edge_attr
        
        poxel_embeds = self.pox_pooler(pocket_batch)
        active_embeds = self.mol_pooler(active_batch)
        decoy_embeds = self.mol_pooler(decoy_batch)
        
        poxel_actives = torch.hstack((poxel_embeds, active_embeds))
        poxel_decoys = torch.hstack((torch.cat([poxel_embeds]*len(decoy_embeds), dim=0), 
                                     decoy_embeds.repeat_interleave(poxel_embeds.size(0), dim=0)))
        
        all_embeds = torch.vstack((poxel_actives, poxel_decoys))

        x = self.linear1(all_embeds) 
        x = self.relu(x)
        
        x = self.linear2(x) 
        x = self.relu(x)
        
        x = self.linear3(x) 
#         x = self.relu(x)
        
#         o = self.linear4(x) 
        return x
    
    
# batch_size=32
# ac = ActiveClassifier(PocketGCN, PoxelPool, MoleculePool)

# corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_3_75.pkl", 'rb'))
# mol_loader = DataLoader(corpus[1], batch_size=batch_size, shuffle=False)

# pocket_loader = DataLoader(corpus[2], batch_size=batch_size, shuffle=False)
# decoy_batch = fetch_decoy_batch(200)

# for active_batch, pocket_batch in zip(mol_loader, pocket_loader):
#     out = ac(pocket_batch, active_batch, decoy_batch)
#     print(out)
#     break

In [None]:
for idx, target_dir in enumerate(root_dir):
    
    if idx < random.randint(0,100):
        continue
        
    print(target_dir)
    target_id = target_dir.split('/')[-2]
    data = pickle.load(open("%s%s_ligand_graph.pkl" % (target_dir, target_id), 'rb'))
    print(data.edge_attr.shape)
    break
    
#     g = torch_geometric.utils.to_networkx(data, to_undirected=True)
#     nx.draw(g)
#     break

In [None]:
# TRAINING LOOP
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([100])).to(device)
# criterion = nn.BCEWithLogitsLoss().to(device)
ac = ActiveClassifier(PocketGCN, PoxelPool, MoleculePool).to(device)
optimizer = torch.optim.Adam(ac.parameters(), lr=1e-3)

def training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        for corpus_idx in range(1,75):
            loss_history = []
            corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_%s_75.pkl" % corpus_idx, 'rb'))
            
            mol_loader = DataLoader(corpus[1], batch_size=batch_size, shuffle=False)
            pocket_loader = DataLoader(corpus[2], batch_size=batch_size, shuffle=False)
            
            for active_batch, pocket_batch in zip(mol_loader, pocket_loader):
                optimizer.zero_grad() 
                active_batch = active_batch.to(device)
                pocket_batch = pocket_batch.to(device)
                decoy_batch = fetch_decoy_batch(batch_size).to(device)
                
                out = ac(pocket_batch, active_batch, decoy_batch)
                
                y = torch.zeros(out.size(0))
                y[:batch_size] = 1
                y = torch.unsqueeze(y, dim=1).to(device)
                
                loss = criterion(out, y)
                loss_history.append(loss.item())
                loss.backward()
                optimizer.step()
                
            print(sum(loss_history) / len(loss_history))
            print('active:', [float("%.3f" % x) for x in torch.sigmoid(out[:7].squeeze()).tolist()])
            print('decoy:', [float("%.3f" % x) for x in torch.sigmoid(out[-7:].squeeze()).tolist()])
            print('')
    
#       for batch_idx in range(int(POXEL_COLLECTION_SIZE / batch_size)):
#         y = torch.hstack([torch.ones(batch_size), torch.zeros(batch_size*np_ratio)]).unsqueeze(dim=1).to(device) 
      
#         pox_batch, active_batch, decoy_batch, active_indices = retrieve_training_batch(batch_size, np_ratio, poxel_collection, col_idx)
    
#         # pox_batch, active_batch, decoy_batch, active_indices = retrieve_DUMMY_batch(batch_size, np_ratio, poxel_collection, COL_IDX, [Ameans, Astds], [Bmeans, Bstds])
#         out = ac(pox_batch.to(device), active_batch.to(device), decoy_batch.to(device), np_ratio)
#         loss = criterion(out, y)
#         print(batch_idx, torch.mean(sigmoid(out[0:10])).item(), torch.mean(sigmoid(out[10:])).item(), loss.item())

        # print(loss)

training_loop(100, 32)
