In [1]:
import pickle
import re
import numpy as np
import sys
import os
import glob
import networkx as nx
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool

import torch
import torch.nn.functional as F

In [2]:
with open('/xdisk/twheeler/jgaiser/deepvs/deepvs/data/protein_config.yaml', 'r') as config_file:
    protein_config = yaml.safe_load(config_file) 

PROTEIN_ATOM_LABELS = protein_config['atom_labels']
PROTEIN_EDGE_LABELS = protein_config['edge_labels']
INTERACTION_LABELS = protein_config['interaction_labels']

dummy_index = PROTEIN_ATOM_LABELS.index('DUMMY')
voxel_edge_index = PROTEIN_EDGE_LABELS.index('voxel')
model_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/models/pocket_embed/'

pocket_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/pockets/'
mol_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/molecules/'
positive_voxel_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/voxel_positives/'
pocket_embed_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/pocket_embeds/'

pocket_file_structure = pocket_graph_dir + "%s_pocket_graph.pkl"
positive_voxel_graph_file_structure = positive_voxel_graph_dir + "%s_voxel_positive_graphs.pkl"
mol_file_structure = mol_graph_dir + "%s_mol_graph.pkl"
pocket_embed_file_structure = pocket_embed_dir + "%s_pocket_embed.pkl"

holdout_complexes = ["3gdt", "3g1v", "3w07", "3g1d", "1loq", "3wjw", "2zz1", "2zz2", "1km3", "1x1z", 
                     "6cbg", "5j7q", "6cbf", "4wrb", "6b1k", "5hvs", "5hvt", "3rf5", "3rf4", "1mfi", 
                     "5efh", "6csq", "5efj", "6csr", "6css", "6csp", "5een", "5ef7", "5eek", "5eei",
                     "3ozt", "3u81", "4p58", "5k03", "3ozr", "3ozs", "3oe5", "3oe4", "3hvi", "3hvj",
                     "3g2y", "3g2z", "3g30", "3g31", "3g34", "3g32", "4de2", "3g35", "4de0", "4de1",
                     "2exm", "4i3z", "1e1v", "5jq5", "1jsv", "1e1x", "4bcp", "4eor", "1b38", "1pxp", "2xnb", "4bco", "4bcm", "1pxn", "4bcn", "1h1s", "4bck", "2fvd", "1pxo", "2xmy",
                     "4xoe", "5fs5", "1uwf", "4att", "4av4", "4av5", "4avh", "4avj", "4avi", "4auj", "4x50", "4lov", "4x5r", "4buq", "4x5p", "4css", "4xoc", "4cst", "4xo8", "4x5q",
                     "1gpk", "3zv7", "1gpn", "5bwc", "5nau", "5nap", "1h23", "1h22", "1e66", "4m0e", "4m0f", "2ha3", "2whp", "2ha6", "2ha2", "1n5r", "4arb", "4ara", "5ehq", "1q84",
                     "2z1w", "3rr4", "1s38", "1q65", "4q4q", "4q4p", "4q4r", "4kwo", "1r5y", "4leq", "4lbu", "1f3e", "4pum", "4q4s", "3gc5", "2qzr", "4q4o", "3gc4", "5jxq", "3ge7"]

PDB_IDS = []

for item in glob.glob(mol_graph_dir + "*"):
    pdb_id = item.split('/')[-1].split('_')[0]
    
    if pdb_id in holdout_complexes:
        continue
        
    PDB_IDS.append(pdb_id)

PDB_IDS = np.array(sorted(PDB_IDS))

class_count = torch.zeros(len(INTERACTION_LABELS))

for s_id in PDB_IDS:
    g = pickle.load(open(mol_file_structure % s_id, 'rb'))
    class_count = torch.add(class_count, torch.sum(g.y, dim=0))
        
interaction_class_weights = (1 / (class_count / torch.max(class_count)))
        

In [3]:
def fetch_training_batch(pdb_ids, active_batch_size, decoy_batch_size):
    active_sample_ids = np.random.choice(pdb_ids, active_batch_size)
    decoy_sample_ids = np.random.choice(pdb_ids, decoy_batch_size)
    
    pocket_graphs = [] 
    active_graphs = []
    decoy_graphs = []
    
    for s_id in active_sample_ids:
        pocket_graphs.append(pickle.load(open(pocket_embed_file_structure % s_id, 'rb')))
        active_graphs.append(pickle.load(open(mol_file_structure % s_id, 'rb')))
        
    for s_id in decoy_sample_ids:
        decoy_graphs.append(pickle.load(open(mol_file_structure % s_id, 'rb')))
    
    pocket_loader = DataLoader(pocket_graphs, batch_size=active_batch_size, shuffle=False)
    active_mol_loader = DataLoader(active_graphs, batch_size=active_batch_size, shuffle=False)
    decoy_mol_loader = DataLoader(decoy_graphs, batch_size=decoy_batch_size, shuffle=False)
    
    return next(iter(pocket_loader)), next(iter(active_mol_loader)), next(iter(decoy_mol_loader))

In [4]:
# MOLECULE MODEL
mol_hidden = 512 

class AtomicAttentiveFP(AttentiveFP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.atom_classifier = nn.Linear(kwargs['hidden_channels'], 9)
        
    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.atom_classifier(x), self.lin2(out)

class MoleculePool(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = AtomicAttentiveFP(in_channels=52, 
                                 hidden_channels=mol_hidden, 
                                 out_channels=mol_hidden,
                                 edge_dim=10,
                                 num_layers=5,
                                 num_timesteps=5,
                                 dropout=0.0)
    
    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.conv1(x, edge_index, edge_weights, batch)
        return h

In [5]:
hidden = 512 

class PoxelGCN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = GCN2Conv(hidden, 0.1)
        self.conv2 = GCN2Conv(hidden, 0.1)

        self.pool1 = SAGPooling(hidden)

        self.conv3 = GCN2Conv(hidden, 0.2)
        self.conv4 = GCN2Conv(hidden, 0.2)

        self.pool2 = SAGPooling(hidden)

        gate_nn = MLP([hidden, 1], act='relu')
        nn = MLP([hidden, hidden], act='relu')
        self.global_pool = AttentionalAggregation(gate_nn, nn)

    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.conv1(x, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv2(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool1(h, edge_index, edge_weights, batch)

        h = self.conv3(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool2(h, edge_index, edge_weights, batch)

        h = self.conv4(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.global_pool(h, index=batch)
        return h


In [6]:
# ACTIVE CLASSIFIER
#@title Classifier

class ActiveClassifier(torch.nn.Module):
    def __init__(self, poxel_model, molecule_model):
        super(ActiveClassifier, self).__init__()
        self.pox_pooler = poxel_model()
        self.mol_pooler = molecule_model()

        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 1)

        self.relu = nn.ReLU()

    def forward(self, pocket_batch, active_batch, decoy_batch):
        poxel_embeds = self.pox_pooler(pocket_batch)
        
        active_preds, active_embeds = self.mol_pooler(active_batch)
        decoy_preds, decoy_embeds = self.mol_pooler(decoy_batch)
        mol_atom_preds = torch.vstack((active_preds, decoy_preds))
        
        
        poxel_actives = torch.hstack((poxel_embeds, active_embeds))
        poxel_decoys = torch.hstack((torch.cat([poxel_embeds]*len(decoy_embeds), dim=0), 
                                     decoy_embeds.repeat_interleave(poxel_embeds.size(0), dim=0)))
        
        all_embeds = torch.vstack((poxel_actives, poxel_decoys))

        x = self.linear1(all_embeds) 
        x = self.relu(x)
        
        x = self.linear2(x) 
        x = self.relu(x)
        
        x = self.linear3(x) 
        return x, mol_atom_preds

In [7]:
ac_model = ActiveClassifier(PoxelGCN, MoleculePool)
poxel_batch, active_batch, decoy_batch = fetch_training_batch(PDB_IDS, 32, 32)

out1, out2 = ac_model(poxel_batch, active_batch, decoy_batch)

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [11]:
# TRAINING LOOP
BATCH_SIZE=32
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([100])).to(device)
mol_criterion = nn.BCEWithLogitsLoss(pos_weight=interaction_class_weights).to(device)

ac = ActiveClassifier(PoxelGCN, MoleculePool).to(device)
optimizer = torch.optim.Adam(ac.parameters(), lr=1e-4)

def training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        loss_history = []

        for batch_idx in range(int(len(PDB_IDS) / BATCH_SIZE)):
            optimizer.zero_grad() 
           
            poxel_batch, active_batch, decoy_batch = fetch_training_batch(PDB_IDS, BATCH_SIZE, 10)
            
            mol_y = torch.vstack((active_batch.y, decoy_batch.y))
            
            mol_train_mask = torch.where(torch.sum(mol_y, dim=1) > 0)
            
            poxel_batch = poxel_batch.to(device)
            active_batch = active_batch.to(device)
            decoy_batch = decoy_batch.to(device)

            out, atom_out = ac(poxel_batch, active_batch, decoy_batch)

            y = torch.zeros(out.size(0))
            y[:batch_size] = 1
            y = torch.unsqueeze(y, dim=1).to(device)

            classification_loss = criterion(out, y)
            mol_loss = mol_criterion(atom_out[mol_train_mask], mol_y[mol_train_mask].to(device))
            
#             loss = classification_loss + 0.5*mol_loss
            loss = classification_loss 
            
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()

            if batch_idx % 20 == 0:
                print(sum(loss_history) / len(loss_history))
                print(classification_loss.item(), mol_loss.item())
                print('active:', [float("%.3f" % x) for x in torch.sigmoid(out[:7].squeeze()).tolist()])
                print('decoy:', [float("%.3f" % x) for x in torch.sigmoid(out[-7:].squeeze()).tolist()])
                print('')
                loss_history=[]
    
#       for batch_idx in range(int(POXEL_COLLECTION_SIZE / batch_size)):
#         y = torch.hstack([torch.ones(batch_size), torch.zeros(batch_size*np_ratio)]).unsqueeze(dim=1).to(device) 
      
#         pox_batch, active_batch, decoy_batch, active_indices = retrieve_training_batch(batch_size, np_ratio, poxel_collection, col_idx)
    
#         # pox_batch, active_batch, decoy_batch, active_indices = retrieve_DUMMY_batch(batch_size, np_ratio, poxel_collection, COL_IDX, [Ameans, Astds], [Bmeans, Bstds])
#         out = ac(pox_batch.to(device), active_batch.to(device), decoy_batch.to(device), np_ratio)
#         loss = criterion(out, y)
#         print(batch_idx, torch.mean(sigmoid(out[0:10])).item(), torch.mean(sigmoid(out[10:])).item(), loss.item())

        # print(loss)

training_loop(100, 32)


NVIDIA A100 80GB PCIe with CUDA capability sm_80 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.
If you want to use the NVIDIA A100 80GB PCIe GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



EPOCH 0


RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.