In [1]:
import pickle
import os
import yaml
import sys
import torch
import torch_geometric
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import GATv2Conv 
from copy import deepcopy

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch.nn import Sequential as Seq, Linear, ReLU, SiLU

from math import log
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.typing import Adj, OptTensor

# from ..inits import glorot
from torch.nn.init import xavier_uniform_ as glorot


In [2]:
def _load_yaml(path):
    """Open ``path`` for reading and return the object parsed by ``yaml.safe_load``."""
    with open(path, 'r') as handle:
        return yaml.safe_load(handle)


# Project-level config, then the protein-specific config it points to.
config = _load_yaml('config.yaml')
protein_config = _load_yaml(config['protein_config_file'])

In [None]:
# Load the pickled interaction corpus; the `with` block guarantees the file is closed.
# NOTE(review): pickle.load can execute arbitrary code — only load corpora produced by
# this project's own pipeline, never files from untrusted sources.
corpus_path = os.path.join(config['interaction_voxel_graph_dir'], "interaction_voxel_corpus.pkl")
with open(corpus_path, 'rb') as corpus_file:
    interaction_corpus = pickle.load(corpus_file)

# Pristine copy so the preprocessing cell can rebuild from it and stay re-runnable.
og_corpus = deepcopy(interaction_corpus)

In [4]:
# Element 0 of the corpus, converted to a NumPy array
# (presumably one interaction-type label per graph — TODO confirm).
corpus_targets = interaction_corpus[0]
target_list = np.array(corpus_targets)

In [5]:
# Divisor applied to the raw edge-weight column
# (presumably the maximum weight over the corpus — TODO confirm).
MAX_EDGE_WEIGHT = 15.286330223083496
# Number of trailing node-feature columns that are split off into `x_coords`.
NUM_COORD_DIMS = 3


def preprocess_interaction_graph(graph, max_edge_weight=MAX_EDGE_WEIGHT, num_coord_dims=NUM_COORD_DIMS):
    """Build the model-ready ``Data`` object for one interaction graph.

    Steps (the input ``graph`` itself is not modified):
      1. Divide the last edge-feature column by ``max_edge_weight``.
      2. Append one extra edge-feature column that is 0 for every original edge.
      3. Add one self-loop (i, i) per node. Its edge features are all 0 except the
         new last column, which is 1 (a self-loop indicator).
      4. Move the trailing ``num_coord_dims`` node-feature columns
         (presumably xyz coordinates) into ``x_coords``.

    Args:
        graph: object with ``x``, ``edge_index``, ``edge_attr`` and ``y`` tensors
            (a torch_geometric ``Data`` in this notebook).
        max_edge_weight: divisor for the last edge-feature column.
        num_coord_dims: number of trailing node-feature columns to split off.

    Returns:
        A new ``Data(x, x_coords, edge_index, edge_attr, y)``; ``y`` is passed through.
    """
    num_nodes = graph.x.size(0)

    # Vectorized scaling of the last column (replaces a per-edge Python loop).
    # clone() so the caller's tensor is left untouched by the in-place division.
    edge_attr = graph.edge_attr.clone()
    edge_attr[:, -1] /= max_edge_weight

    # Self-loop indices: a (2, num_nodes) tensor whose rows are both 0..num_nodes-1.
    loop_index = torch.arange(num_nodes, device=graph.edge_index.device).repeat(2, 1)
    edge_index = torch.cat((graph.edge_index, loop_index), dim=1)

    # Extra indicator column: 0 for the original edges ...
    edge_attr = torch.cat((edge_attr, edge_attr.new_zeros((edge_attr.size(0), 1))), dim=1)
    # ... and 1 for the self-loops, whose other features are all 0.
    loop_attr = edge_attr.new_zeros((num_nodes, edge_attr.size(1)))
    loop_attr[:, -1] = 1
    edge_attr = torch.cat((edge_attr, loop_attr), dim=0)

    x_coords = graph.x[:, -num_coord_dims:]
    x = graph.x[:, :-num_coord_dims]

    return Data(x=x,
                x_coords=x_coords,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=graph.y)


# Rebuild from the pristine copy so this cell is idempotent on re-run.
interaction_corpus = deepcopy(og_corpus)

graphs = interaction_corpus[1]
for g_idx, graph in enumerate(graphs):
    # Element-wise replacement keeps the original container object.
    graphs[g_idx] = preprocess_interaction_graph(graph)
    

In [6]:
# for g_idx in range(len(interaction_corpus[1])):
#     interaction_corpus[1][g_idx].edge_attr = torch.hstack((interaction_corpus[1][0].edge_attr,
#                                                        torch.zeros((len(interaction_corpus[1][0].edge_attr)),1)))

In [7]:
# Mini-batches of 64 preprocessed graphs, reshuffled on every pass over the data.
BATCH_SIZE = 64
loader = DataLoader(interaction_corpus[1], shuffle=True, batch_size=BATCH_SIZE)

In [8]:
# Per-class weights for the weighted cross-entropy loss.
# Each class is weighted by (largest class count / its own count), so the most
# frequent class gets 1.0 and rarer classes get proportionally more weight.
NUM_CLASSES = len(protein_config['interaction_labels'])
# NOTE(review): manual override carried over from the original notebook. The computed
# weight of class 4 is replaced by 60. The reason was not recorded — confirm it is still wanted.
CLASS_WEIGHT_OVERRIDES = {4: 60.0}

class_count = [0] * NUM_CLASSES
for batch in loader:
    # reshape(-1) so 1-D and (N, 1)-shaped label tensors are counted the same way.
    for class_idx in batch.y.reshape(-1).tolist():
        class_count[class_idx] += 1

example_count = sum(class_count)
class_max = max(class_count)
# A class absent from the data gets weight 0 instead of raising ZeroDivisionError.
# It can never be a target, so its weight never enters the loss.
class_weights = torch.tensor([class_max / count if count else 0.0 for count in class_count])
for class_idx, weight in CLASS_WEIGHT_OVERRIDES.items():
    class_weights[class_idx] = weight

class_weights

tensor([61.5491,  1.3592,  1.6214,  1.0000, 60.0000, 52.4025,  9.3566,  9.0610,
        23.3785])


tensor([[-4.4420,  1.6140, 17.4390],
        [-2.4290,  0.4240, 19.1280],
        [-2.7730,  2.6940, 19.9650],
        [-4.2600,  4.8940, 18.4030],
        [-2.1390,  3.6580, 18.9540],
        [-2.0020,  1.3790, 19.9640],
        [-3.3460, -2.0310, 17.7640],
        [-1.6900, -0.8340, 19.0660],
        [-2.8090,  5.0290, 18.8410],
        [-0.6270,  0.1720, 17.1070]])
tensor([[-4.4645,  1.6157, 17.4276],
        [-2.4286,  0.4149, 19.1356],
        [-2.7753,  2.7048, 19.9799],
        [-4.2702,  4.9054, 18.4012],
        [-2.1355,  3.6733, 18.9581],
        [-1.9966,  1.3789, 19.9796],
        [-3.3550, -2.0502, 17.7590],
        [-1.6831, -0.8503, 19.0723],
        [-2.8105,  5.0453, 18.8429],
        [-0.6096,  0.1640, 17.0948]], grad_fn=<SliceBackward0>)


In [13]:
HIDDEN_DIMS = 64

# Model/data dimensions. They were previously defined only in a later cell, so a
# fresh "Restart & Run All" failed here (and in the training cell) with NameError.
# The later GCN cell re-assigns the same values.
INTERACTION_TYPES = protein_config['interaction_labels']
NODE_DIMS = 41  # width of data.x after preprocessing split off the coordinate columns
EDGE_DIMS = 9
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')

assert interaction_corpus[1][0].x.size(1) == NODE_DIMS, "NODE_DIMS does not match preprocessed data.x"

# NOTE(review): `EnConv` is not defined or imported anywhere in this notebook. It
# presumably lived in a since-deleted cell — restore it (or import it) before this cell.


class EnConvNetwork(torch.nn.Module):
    """Node classifier: linear embedding -> three EnConv layers (SiLU) -> linear head.

    Every EnConv layer receives the current hidden state, the initial embedding,
    the current coordinates and the edges. It returns a new hidden state and new
    coordinates, which are fed to the next layer.
    """

    def __init__(self):
        super().__init__()

        self.linear1 = torch.nn.Linear(NODE_DIMS, HIDDEN_DIMS)

        self.conv1 = EnConv(HIDDEN_DIMS, EDGE_DIMS, HIDDEN_DIMS, HIDDEN_DIMS)
        self.conv2 = EnConv(HIDDEN_DIMS, EDGE_DIMS, HIDDEN_DIMS, HIDDEN_DIMS)
        self.conv3 = EnConv(HIDDEN_DIMS, EDGE_DIMS, HIDDEN_DIMS, HIDDEN_DIMS)
        # conv4 is constructed but not used in forward(). It is kept so parameter
        # names (state_dict keys) stay the same as in earlier runs.
        self.conv4 = EnConv(HIDDEN_DIMS, EDGE_DIMS, HIDDEN_DIMS, HIDDEN_DIMS)

        self.linear2 = torch.nn.Linear(HIDDEN_DIMS, len(INTERACTION_TYPES))

    def forward(self, data, show_coords=False):
        """Return per-node logits of shape (num_nodes, len(INTERACTION_TYPES)).

        If ``show_coords`` is True, print the coordinates produced by the last conv layer.
        """
        x, x_coords, edge_index, edge_attr = data.x, data.x_coords, data.edge_index, data.edge_attr

        x = self.linear1(x)

        # The initial embedding `x` is passed to every layer alongside the running state `h`.
        h, xc = self.conv1(x, x, x_coords, edge_index, edge_attr)
        h = F.silu(h)

        h, xc = self.conv2(h, x, xc, edge_index, edge_attr)
        h = F.silu(h)

        h, xc = self.conv3(h, x, xc, edge_index, edge_attr)
        h = F.silu(h)

        if show_coords:
            print(xc)

        return self.linear2(h)


model = EnConvNetwork()

# Smoke test: one forward pass on a single batch (no autograd graph needed).
with torch.no_grad():
    print(model(next(iter(loader))))

tensor([[-0.1432,  0.0609, -0.0444,  ...,  0.1249, -0.0085, -0.1157],
        [-0.1129,  0.0359, -0.0104,  ...,  0.1307, -0.0060, -0.0835],
        [-0.0864, -0.0071,  0.0232,  ...,  0.1481,  0.0122, -0.0692],
        ...,
        [-0.2017,  0.1329, -0.0958,  ...,  0.0865, -0.0286, -0.1312],
        [-0.2146,  0.1457, -0.1016,  ...,  0.0854, -0.0195, -0.1452],
        [-0.2818,  0.2129, -0.1660,  ...,  0.0361, -0.0497, -0.1729]],
       grad_fn=<AddmmBackward0>)


In [14]:
SEED = 0
NUM_EPOCHS = 5000
N_PREVIEW = 15  # number of predictions/labels printed after each epoch

torch.manual_seed(SEED)  # reproducible weight init and loader shuffling (CPU RNG)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EnConvNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))

batch_idx = 0  # global step counter across all epochs
for epoch in range(NUM_EPOCHS):
    model.train()
    batch_losses = []
    print(f"EPOCH {epoch}")

    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        out = model(batch)

        # The loss is computed only on DUMMY nodes. This assumes exactly one DUMMY
        # node per graph, in the same order as batch.y — TODO confirm.
        dummy_mask = batch.x[:, DUMMY_INDEX] == 1
        dummy_node_out = out[dummy_mask]

        loss = loss_function(dummy_node_out, batch.y)
        batch_losses.append(loss.item())
        loss.backward()
        optimizer.step()
        batch_idx += 1

    if not batch_losses:
        raise RuntimeError("loader produced no batches; nothing to train on")

    # Unweighted mean over batches (a smaller final batch counts as much as a full one).
    print("Average loss:", sum(batch_losses) / len(batch_losses))
    # NOTE(review): these are predictions on the last *training* batch, not on a
    # held-out set, so they say little about generalization.
    print(torch.argmax(dummy_node_out, dim=1)[:N_PREVIEW])
    print(batch.y[:N_PREVIEW])

EPOCH 0
Average loss: 1.8782129840525252
tensor([1, 2, 6, 6, 6, 3, 6, 1, 8, 2, 2, 2, 7, 1, 1], device='cuda:0')
tensor([1, 2, 3, 3, 4, 3, 3, 2, 3, 2, 1, 1, 1, 3, 1], device='cuda:0')
EPOCH 1
Average loss: 1.5685814725894847
tensor([6, 3, 3, 3, 0, 1, 1, 1, 5, 1, 7, 1, 2, 1, 3], device='cuda:0')
tensor([4, 1, 3, 3, 3, 2, 1, 1, 3, 2, 1, 1, 2, 3, 3], device='cuda:0')
EPOCH 2
Average loss: 1.2927233609841184
tensor([3, 2, 3, 2, 3, 2, 8, 3, 7, 0, 0, 3, 7, 8, 0], device='cuda:0')
tensor([3, 1, 3, 2, 1, 2, 1, 3, 3, 0, 3, 3, 7, 1, 3], device='cuda:0')
EPOCH 3
Average loss: 1.1632984652797325
tensor([7, 2, 0, 8, 8, 2, 2, 8, 3, 5, 8, 3, 1, 1, 5], device='cuda:0')
tensor([1, 2, 3, 3, 3, 2, 2, 1, 3, 1, 3, 3, 1, 1, 3], device='cuda:0')
EPOCH 4
Average loss: 1.0811491144839596
tensor([1, 3, 2, 6, 2, 0, 6, 8, 0, 1, 0, 1, 6, 3, 0], device='cuda:0')
tensor([1, 3, 2, 6, 2, 0, 3, 3, 0, 1, 3, 1, 6, 3, 3], device='cuda:0')
EPOCH 5
Average loss: 1.0491222622638745
tensor([1, 5, 1, 3, 1, 3, 2, 2, 6, 5, 3, 1, 

Average loss: 0.5778929908914892
tensor([3, 2, 3, 3, 7, 1, 2, 2, 3, 3, 2, 3, 2, 1, 2], device='cuda:0')
tensor([3, 2, 3, 3, 7, 1, 2, 2, 3, 3, 2, 3, 2, 2, 2], device='cuda:0')
EPOCH 46
Average loss: 0.5741334613704071
tensor([6, 6, 6, 0, 2, 5, 6, 6, 4, 7, 3, 6, 2, 2, 3], device='cuda:0')
tensor([3, 6, 3, 2, 2, 3, 3, 3, 1, 1, 3, 6, 2, 2, 3], device='cuda:0')
EPOCH 47
Average loss: 0.5945694990890633
tensor([2, 5, 3, 1, 3, 3, 1, 3, 1, 3, 8, 3, 5, 2, 6], device='cuda:0')
tensor([2, 6, 3, 1, 3, 3, 1, 3, 2, 3, 2, 3, 1, 2, 3], device='cuda:0')
EPOCH 48
Average loss: 0.5831878510427
tensor([0, 8, 3, 2, 2, 3, 3, 3, 3, 1, 3, 2, 2, 3, 3], device='cuda:0')
tensor([3, 1, 3, 2, 2, 3, 3, 3, 1, 1, 3, 2, 2, 3, 3], device='cuda:0')
EPOCH 49
Average loss: 0.5739035049551412
tensor([1, 1, 6, 2, 3, 3, 0, 2, 0, 5, 6, 3, 3, 1, 2], device='cuda:0')
tensor([1, 1, 1, 2, 3, 3, 1, 2, 3, 3, 3, 3, 3, 1, 2], device='cuda:0')
EPOCH 50
Average loss: 0.5532921630237384
tensor([2, 6, 2, 6, 1, 8, 2, 3, 6, 1, 2, 8, 1, 2, 

Average loss: 0.4937234476399456
tensor([7, 1, 2, 0, 3, 1, 3, 0, 8, 3, 3, 0, 0, 1, 0], device='cuda:0')
tensor([1, 1, 2, 0, 3, 1, 3, 3, 3, 3, 3, 3, 0, 1, 3], device='cuda:0')
EPOCH 91
Average loss: 0.47092688688809636
tensor([1, 1, 2, 4, 2, 3, 1, 5, 3, 2, 2, 7, 2, 2, 3], device='cuda:0')
tensor([1, 3, 2, 1, 2, 3, 1, 3, 3, 2, 1, 1, 2, 2, 3], device='cuda:0')
EPOCH 92
Average loss: 0.45733596153483114
tensor([1, 5, 3, 7, 1, 3, 3, 2, 3, 1, 3, 6, 2, 2, 1], device='cuda:0')
tensor([2, 3, 1, 1, 1, 3, 3, 2, 3, 2, 3, 3, 2, 2, 2], device='cuda:0')
EPOCH 93
Average loss: 0.5663376140857319
tensor([3, 2, 0, 2, 8, 1, 1, 3, 1, 1, 1, 6, 1, 5, 2], device='cuda:0')
tensor([3, 2, 3, 3, 8, 2, 3, 3, 1, 0, 1, 3, 1, 5, 2], device='cuda:0')
EPOCH 94
Average loss: 0.4473488338683781
tensor([6, 2, 1, 6, 1, 2, 3, 3, 1, 3, 2, 2, 3, 2, 3], device='cuda:0')
tensor([3, 2, 1, 3, 1, 2, 3, 3, 1, 3, 2, 2, 3, 2, 3], device='cuda:0')
EPOCH 95
Average loss: 0.4129271182532331
tensor([3, 3, 1, 7, 2, 1, 5, 3, 3, 3, 3, 0, 7

KeyboardInterrupt: 

In [None]:
hidden = 2048
INTERACTION_TYPES = protein_config['interaction_labels']
NODE_DIMS = 41
EDGE_DIMS = 9
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')
# Column of edge_attr that holds the scaled edge weight. Preprocessing divides the
# then-last column by MAX_EDGE_WEIGHT and *afterwards* appends a self-loop indicator
# column. The weight therefore sits at -2; column -1 is 0 for every real edge.
EDGE_WEIGHT_COL = -2


class GCN(torch.nn.Module):
    """Node classifier: linear embedding -> GCN2Conv + ReLU -> linear head.

    ``data.x`` already has its coordinate columns removed by preprocessing, so it is
    used whole. conv2/conv3 are constructed but currently unused in ``forward``.
    """

    def __init__(self, hidden_channels=hidden, alpha=0.2):
        # The default for ``hidden_channels`` is bound when the class is defined, so a
        # later ``hidden = ...`` in another cell no longer changes this model's width.
        super().__init__()
        self.linear1 = torch.nn.Linear(NODE_DIMS, hidden_channels)

        self.conv1 = GCN2Conv(hidden_channels, alpha)
        self.conv2 = GCN2Conv(hidden_channels, alpha)
        self.conv3 = GCN2Conv(hidden_channels, alpha)

        self.linear2 = torch.nn.Linear(hidden_channels, len(INTERACTION_TYPES))

    def forward(self, data):
        """Return per-node logits of shape (num_nodes, len(INTERACTION_TYPES))."""
        x, edge_index = data.x, data.edge_index
        edge_weights = data.edge_attr[:, EDGE_WEIGHT_COL]
        # NOTE(review): the explicit self-loops added in preprocessing have weight 0 in
        # this column, and GCN2Conv may add its own self-loops during normalization —
        # confirm how these interact.

        x = self.linear1(x)

        # GCN2Conv takes (x, x_0, edge_index, edge_weight). x_0 is the initial
        # embedding used for its residual term.
        h = F.relu(self.conv1(x, x, edge_index, edge_weights))

        return self.linear2(h)

In [None]:
hidden = 1024
# Edge features assigned to self-loops: all zeros except the last (self-loop indicator)
# column. This matches the self-loop features built in preprocessing.
SELF_EDGE = torch.zeros(EDGE_DIMS, dtype=torch.float32)
SELF_EDGE[-1] = 1


class GAT(torch.nn.Module):
    """Node classifier: linear embedding -> two GATv2Conv + ReLU layers -> linear head.

    ``data.x`` already has its coordinate columns removed by preprocessing, so it is
    used whole. conv3 is constructed but currently unused in ``forward``.
    """

    def __init__(self, hidden_channels=hidden):
        # The default is bound when the class is defined, so rebinding the global
        # ``hidden`` later does not silently change this model's width.
        super().__init__()
        self.linear1 = torch.nn.Linear(NODE_DIMS, hidden_channels)
        self.linear2 = torch.nn.Linear(hidden_channels, len(INTERACTION_TYPES))

        self.conv1 = GATv2Conv(hidden_channels, hidden_channels, edge_dim=EDGE_DIMS, fill_value=SELF_EDGE)
        self.conv2 = GATv2Conv(hidden_channels, hidden_channels, edge_dim=EDGE_DIMS, fill_value=SELF_EDGE)
        self.conv3 = GATv2Conv(hidden_channels, hidden_channels, edge_dim=EDGE_DIMS, fill_value=SELF_EDGE)

    def forward(self, data):
        """Return per-node logits of shape (num_nodes, len(INTERACTION_TYPES))."""
        x, edge_index, edge_features = data.x, data.edge_index, data.edge_attr

        x = self.linear1(x)

        h = F.relu(self.conv1(x, edge_index, edge_features))
        # Feed conv1's output forward. Previously conv2 was applied to `x`,
        # which discarded conv1 entirely.
        h = F.relu(self.conv2(h, edge_index, edge_features))

        return self.linear2(h)

In [None]:
# Scratch check: Euclidean distance between every row of `a` and every row of `b`.
# `pdist` was never defined in this notebook (NameError on a fresh kernel).
# torch.cdist gives the full pairwise matrix; row-wise distances are its diagonal.
# cdist requires floating-point inputs, hence .float().
a = torch.tensor([[1, 2, 3],
                  [25, 0, 0]])

b = torch.tensor([[4, 5, 6],
                  [0, 25, 0]])

pairwise_dist = torch.cdist(a.float(), b.float())
pairwise_dist
