In [2]:
import pickle
import os
import yaml
import sys
import torch
import torch_geometric
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import GATv2Conv 
from copy import deepcopy

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch.nn import Sequential as Seq, Linear, ReLU


In [3]:
# Top-level notebook configuration.
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# The main config names a second YAML file; it is read later for the
# 'interaction_labels' and 'atom_labels' lists.
protein_config_path = config['protein_config_file']
with open(protein_config_path, 'r') as config_file:
    protein_config = yaml.safe_load(config_file)

In [4]:
# NOTE(review): unpickling can execute arbitrary code -- only load corpus files from a trusted source.
corpus_path = config['interaction_voxel_graph_dir'] + "interaction_voxel_corpus.pkl"
# The context manager closes the file handle, which the bare open() call left dangling.
with open(corpus_path, 'rb') as corpus_file:
    interaction_corpus = pickle.load(corpus_file)

# Pristine copy so the preprocessing cell can be re-run from the untouched corpus.
og_corpus = deepcopy(interaction_corpus)

In [5]:
# Element 0 of the corpus as a NumPy array (element 1 is the list of graphs used below).
# NOTE(review): presumably the per-graph targets -- confirm against the corpus builder.
target_list = np.array(interaction_corpus[0])
# graph_list = np.array(interaction_corpus[1])

In [6]:
# Divisor applied to the last edge-attribute column.
# NOTE(review): presumably the largest raw edge weight in the corpus -- confirm its origin.
MAX_EDGE_WEIGHT = 15.286330223083496

# Start from the pristine copy so re-running this cell never normalizes or pads twice.
interaction_corpus = deepcopy(og_corpus)

for g_idx in range(len(interaction_corpus[1])):
    updated_edge_index = deepcopy(interaction_corpus[1][g_idx].edge_index)
    updated_edge_attr = deepcopy(interaction_corpus[1][g_idx].edge_attr)
    updated_x = deepcopy(interaction_corpus[1][g_idx].x)

    # Normalize the last edge-attribute column for every edge in one vectorized
    # in-place op (same result as dividing edge by edge in a Python loop).
    updated_edge_attr[:, -1] /= MAX_EDGE_WEIGHT

    # One (i, i) self loop per node, appended after the existing edges.
    self_edges = torch.arange(updated_x.size(0)).unsqueeze(0)
    self_edges = torch.cat((self_edges, self_edges), dim=0)
    updated_edge_index = torch.hstack((updated_edge_index, self_edges))

    # New trailing column acts as a self-loop flag: 0 for the original edges...
    updated_edge_attr = torch.hstack((updated_edge_attr, torch.zeros((updated_edge_attr.size(0), 1))))

    # ...and 1 for the self loops, whose other attributes are all zero.
    self_loop_features = torch.zeros((updated_x.size(0), updated_edge_attr.size(1)))
    self_loop_features[:, -1] = 1

    updated_edge_attr = torch.vstack((updated_edge_attr, self_loop_features))

    # Replace the graph with its padded version; labels are carried over as-is.
    interaction_corpus[1][g_idx] = Data(x=updated_x,
                                        edge_index=updated_edge_index,
                                        edge_attr=updated_edge_attr,
                                        y=interaction_corpus[1][g_idx].y)
    

In [7]:
# for g_idx in range(len(interaction_corpus[1])):
#     interaction_corpus[1][g_idx].edge_attr = torch.hstack((interaction_corpus[1][0].edge_attr,
#                                                        torch.zeros((len(interaction_corpus[1][0].edge_attr)),1)))

In [8]:
# Mini-batches of 64 graphs drawn from the corpus graph list, reshuffled on every pass.
loader = DataLoader(interaction_corpus[1], batch_size=64, shuffle=True)

In [9]:
# Number of examples per interaction label, indexed by label id.
class_count = [0] * len(protein_config['interaction_labels'])

for batch in loader:
    # Convert labels to Python ints once per batch instead of indexing the list with 0-d tensors.
    for class_idx in batch.y.reshape(-1).tolist():
        class_count[class_idx] += 1

example_count = sum(class_count)
class_max = max(class_count)

# Inverse-frequency weights relative to the most common class (most common class -> 1.0).
# An earlier `1 - count/example_count` weighting was a dead store overwritten by this
# line and has been dropped.
# NOTE(review): a label with zero examples raises ZeroDivisionError here.
class_weights = torch.tensor([class_max/x for x in class_count])
# NOTE(review): manual override for label 4 -- presumably a cap on its weight; confirm the rationale.
class_weights[4] = 60

print(class_weights)

tensor([61.5491,  1.3592,  1.6214,  1.0000, 60.0000, 52.4025,  9.3566,  9.0610,
        23.3785])


In [15]:
# Label lists come from the protein config.
INTERACTION_TYPES = protein_config['interaction_labels']
# Position of the 'DUMMY' entry in the atom label list.
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')

# Feature widths. NOTE(review): NODE_DIMS presumably includes the 3 trailing
# coordinate columns that the layers below slice off -- confirm against the corpus.
NODE_DIMS = 41
EDGE_DIMS = 9
hidden = 1024

# Row-wise Euclidean (p=2) distance, used inside EnConv.message.
pdist = torch.nn.PairwiseDistance(p=2)

class EnConv(MessagePassing):
    """Message-passing layer that updates node features and 3 trailing coordinate columns.

    The last three columns of the node matrix are treated as coordinates: they
    bypass the input projection, provide the squared-distance term of every
    message, and are shifted by a summed, degree-normalized displacement.

    Args:
        in_channels: Width of the input node matrix, including the 3 coordinate columns.
        hidden_channels: Width of the projected node features and of each edge message.
        out_channels: Width of the updated node features (coordinates excluded).
        edge_channels: Number of edge-attribute columns fed to the edge MLP, i.e.
            the width of ``edge_attr`` minus the one column ``message`` drops.
            The default of 8 reproduces the formerly hard-coded input width
            (2057) when ``hidden_channels`` is 1024.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, edge_channels=8):
        super().__init__(aggr='add')
        # Edge MLP input = [h_i, h_j] (2 * hidden) + squared distance (1) + kept edge attributes.
        # Derived from the arguments so hidden sizes other than 1024 work.
        edge_in_channels = 2 * hidden_channels + 1 + edge_channels
        self.edge_conv = Seq(Linear(edge_in_channels, hidden_channels),
                             ReLU(),
                             Linear(hidden_channels, hidden_channels))

        # Maps an edge message to a scalar that scales the coordinate displacement.
        self.coord_weight = Seq(Linear(hidden_channels, hidden_channels),
                              ReLU(),
                              Linear(hidden_channels, 1))

        # Combines a node's projected features with its aggregated messages.
        self.update_h = Seq(Linear(2*hidden_channels, hidden_channels),
                              ReLU(),
                              Linear(hidden_channels, out_channels))

        # Input projection for the non-coordinate columns.
        self.lin = Linear(in_channels-3, hidden_channels)

    def forward(self, x, edge_index, edge_attr):
        """Apply the layer.

        Args:
            x: Node matrix; the last 3 columns are used as coordinates.
            edge_index: Two-row edge index tensor.
            edge_attr: Per-edge attribute matrix.

        Returns:
            Tensor with ``out_channels`` feature columns followed by the 3
            updated coordinate columns.
        """
        # Project the features, then re-attach the raw coordinates so message() can see them.
        h = self.lin(x[:,:-3])
        h = torch.hstack((h, x[:,-3:]))

        # Symmetric degree normalization; zero-degree nodes give inf, which is reset to 0.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = torch.unsqueeze(deg_inv_sqrt[row] * deg_inv_sqrt[col], dim=1)

        # Aggregated output packs [summed messages | summed coordinate displacement].
        out = self.propagate(edge_index, x=h, edge_attr=edge_attr, norm=norm)
        m_i, coord_mod = (out[:,:-3], out[:,-3:])

        updated_h = self.update_h(torch.hstack((h[:,:-3],m_i)))
        updated_coords = x[:,-3:] + coord_mod
        return torch.hstack((updated_h, updated_coords))


    def message(self, x_i, x_j, edge_attr, norm):
        """Build the per-edge message and coordinate displacement.

        Returns the edge message (``hidden_channels`` columns) concatenated with
        the 3-column displacement so both are aggregated by a single propagate call.
        """
        # Features of both endpoints plus their squared Euclidean distance.
        node_features = torch.cat((x_i[:,:-3],x_j[:,:-3]), dim=1)
        node_squared_distances = torch.unsqueeze(pdist(x_i[:,-3:], x_j[:,-3:])**2,dim=1)
        node_features = torch.hstack((node_features, node_squared_distances))

        # Keep every edge attribute except the second-to-last column.
        # NOTE(review): presumably the normalized edge weight, superseded by the
        # squared distance computed above -- confirm.
        edge_features = torch.hstack((edge_attr[:,:-2], edge_attr[:,-1:]))
        conv_features = torch.hstack((node_features,edge_features))

        x_i_coords = x_i[:,-3:]
        x_j_coords = x_j[:,-3:]

        m_ij = self.edge_conv(conv_features)

        # Displacement along the edge direction, scaled by a learned scalar and the degree norm.
        coord_mod = norm * (x_i_coords - x_j_coords) * self.coord_weight(m_ij)
        return torch.hstack((m_ij, coord_mod))
         
        
# Smoke test: push a single batch through the layer and show the output shape.
# ec = EnConv(NODE_DIMS, 1024, len(INTERACTION_TYPES)) 
ec = EnConv(in_channels=NODE_DIMS, hidden_channels=1024, out_channels=1024)

for batch in loader:
    x = batch.x
    edge_index = batch.edge_index
    edge_features = batch.edge_attr
    layer_out = ec(x, edge_index, edge_features)
    print(layer_out.shape)
    break  # one batch is enough for a shape check

torch.Size([768, 1027])


In [None]:
# Label lists come from the protein config.
INTERACTION_TYPES = protein_config['interaction_labels']
# Position of the 'DUMMY' entry in the atom label list.
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')

# Feature widths. NOTE(review): NODE_DIMS presumably includes the 3 trailing
# coordinate columns that the layers below slice off -- confirm against the corpus.
NODE_DIMS = 41
EDGE_DIMS = 9
hidden = 1024

# Row-wise Euclidean (p=2) distance, used inside EnConv.message.
pdist = torch.nn.PairwiseDistance(p=2)

class EnConv(MessagePassing):
    """Message-passing layer that updates node features and 3 trailing coordinate columns.

    Work-in-progress variant whose ``forward`` accepts a separate ``x_coords``
    argument. That argument is not used yet: coordinates are still read from
    the last three columns of ``x``.

    Args:
        in_channels: Width of the input node matrix, including the 3 coordinate columns.
        hidden_channels: Width of the projected node features and of each edge message.
        out_channels: Width of the updated node features (coordinates excluded).
        edge_channels: Number of edge-attribute columns fed to the edge MLP, i.e.
            the width of ``edge_attr`` minus the one column ``message`` drops.
            The default of 8 reproduces the formerly hard-coded input width
            (2057) when ``hidden_channels`` is 1024.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, edge_channels=8):
        super().__init__(aggr='add')
        # Edge MLP input = [h_i, h_j] (2 * hidden) + squared distance (1) + kept edge attributes.
        # Derived from the arguments so hidden sizes other than 1024 work.
        edge_in_channels = 2 * hidden_channels + 1 + edge_channels
        self.edge_conv = Seq(Linear(edge_in_channels, hidden_channels),
                             ReLU(),
                             Linear(hidden_channels, hidden_channels))

        # Maps an edge message to a scalar that scales the coordinate displacement.
        self.coord_weight = Seq(Linear(hidden_channels, hidden_channels),
                              ReLU(),
                              Linear(hidden_channels, 1))

        # Combines a node's projected features with its aggregated messages.
        self.update_h = Seq(Linear(2*hidden_channels, hidden_channels),
                              ReLU(),
                              Linear(hidden_channels, out_channels))

        # Input projection for the non-coordinate columns.
        self.lin = Linear(in_channels-3, hidden_channels)

    def forward(self, x, x_coords, edge_index, edge_attr):
        """Apply the layer.

        Args:
            x: Node matrix; the last 3 columns are used as coordinates.
            x_coords: Currently ignored (see TODO below).
            edge_index: Two-row edge index tensor.
            edge_attr: Per-edge attribute matrix.

        Returns:
            Tensor with ``out_channels`` feature columns followed by the 3
            updated coordinate columns.
        """
        # TODO: Implement x_coords parameter functionality
        # Project the features, then re-attach the raw coordinates so message() can see them.
        h = self.lin(x[:,:-3])
        h = torch.hstack((h, x[:,-3:]))

        # Symmetric degree normalization; zero-degree nodes give inf, which is reset to 0.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = torch.unsqueeze(deg_inv_sqrt[row] * deg_inv_sqrt[col], dim=1)

        # Aggregated output packs [summed messages | summed coordinate displacement].
        out = self.propagate(edge_index, x=h, edge_attr=edge_attr, norm=norm)
        m_i, coord_mod = (out[:,:-3], out[:,-3:])

        updated_h = self.update_h(torch.hstack((h[:,:-3],m_i)))
        updated_coords = x[:,-3:] + coord_mod
        return torch.hstack((updated_h, updated_coords))


    def message(self, x_i, x_j, edge_attr, norm):
        """Build the per-edge message and coordinate displacement.

        Returns the edge message (``hidden_channels`` columns) concatenated with
        the 3-column displacement so both are aggregated by a single propagate call.
        """
        # Features of both endpoints plus their squared Euclidean distance.
        node_features = torch.cat((x_i[:,:-3],x_j[:,:-3]), dim=1)
        node_squared_distances = torch.unsqueeze(pdist(x_i[:,-3:], x_j[:,-3:])**2,dim=1)
        node_features = torch.hstack((node_features, node_squared_distances))

        # Keep every edge attribute except the second-to-last column.
        # NOTE(review): presumably the normalized edge weight, superseded by the
        # squared distance computed above -- confirm.
        edge_features = torch.hstack((edge_attr[:,:-2], edge_attr[:,-1:]))
        conv_features = torch.hstack((node_features,edge_features))

        x_i_coords = x_i[:,-3:]
        x_j_coords = x_j[:,-3:]

        m_ij = self.edge_conv(conv_features)

        # Displacement along the edge direction, scaled by a learned scalar and the degree norm.
        coord_mod = norm * (x_i_coords - x_j_coords) * self.coord_weight(m_ij)
        return torch.hstack((m_ij, coord_mod))
         
        
# Smoke test: push a single batch through the layer and show the output shape.
# ec = EnConv(NODE_DIMS, 1024, len(INTERACTION_TYPES)) 
ec = EnConv(NODE_DIMS, 1024, 1024) 

for batch in loader:
    x, edge_index, edge_features = batch.x, batch.edge_index, batch.edge_attr
    # forward() in this cell takes x_coords as its second argument; the previous
    # three-argument call raised TypeError. Pass the trailing coordinate columns of x.
    print(ec(x, x[:, -3:], edge_index, edge_features).shape)
    break

In [None]:
# Scratch check: multiplying a (4, 3) tensor by a (4, 1) column scales each row
# by its own factor via broadcasting.
a = torch.tensor([[1, 2, 3],
                  [4, 8, 10],
                  [3, 8, 20],
                  [5, 6, 7]])
b = torch.tensor([1, 2, 3, 5]).unsqueeze(1)

print(a.shape, b.shape, sep="\n")

a * b

In [None]:
# Label lists come from the protein config.
INTERACTION_TYPES = protein_config['interaction_labels']
# Position of the 'DUMMY' entry in the atom label list.
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')

# Feature widths. NOTE(review): NODE_DIMS presumably includes the 3 trailing
# coordinate columns that the models slice off -- confirm against the corpus.
NODE_DIMS = 41
EDGE_DIMS = 9

# Hidden width read by GCN.__init__ when the model is built.
hidden = 2048


class GCN(torch.nn.Module):
    """Node classifier: linear projection -> one GCN2Conv layer -> linear head.

    Layer sizes are read from the module-level ``NODE_DIMS``, ``hidden`` and
    ``INTERACTION_TYPES`` at construction time. ``conv2`` and ``conv3`` are
    created but not used by ``forward`` at the moment.
    """
    def __init__(self):
        super().__init__()
        # Fixed: the original referenced the undefined name NODE_DIM (NameError on instantiation).
        # The last 3 node columns are sliced off in forward(), hence the -3.
        self.linear1 = torch.nn.Linear(NODE_DIMS-3,hidden)

        self.conv1 = GCN2Conv(hidden, 0.2)
        self.conv2 = GCN2Conv(hidden, 0.2)
        self.conv3 = GCN2Conv(hidden, 0.2)

        self.linear2 = torch.nn.Linear(hidden, len(INTERACTION_TYPES))

    def forward(self, data):
        """Return one row of class logits per node of ``data``."""
        # NOTE(review): the edge weight is taken from the LAST edge_attr column. After the
        # preprocessing cell that column is the self-loop flag (0 for original edges,
        # 1 for self loops), not the normalized weight -- confirm this is intended.
        x, edge_index, edge_weights = data.x[:,:-3], data.edge_index, data.edge_attr[:,-1]

#         o = self.linear3(x)
        x = self.linear1(x)

        # The projected input is passed as both the current and the initial representation.
        h = self.conv1(x, x, edge_index, edge_weights)
        h = F.relu(h)

#         h = self.conv2(h, x, edge_index, edge_weights)
#         h = F.relu(h)

#         h = self.conv2(h, x, edge_index, edge_weights)
#         h = F.relu(h)

        o = self.linear2(h)

#         return o,h
        return o

In [None]:
# Hidden width read by GAT.__init__ when the model is built.
hidden = 1024

# Edge-attribute vector handed to GATv2Conv as `fill_value`:
# all zeros except a 1 in the last position.
SELF_EDGE = torch.zeros(EDGE_DIMS, dtype=torch.float32)
SELF_EDGE[-1] = 1.0

class GAT(torch.nn.Module):
    """Node classifier: linear projection -> two stacked GATv2Conv layers -> linear head.

    Layer sizes are read from the module-level ``NODE_DIMS``, ``hidden``,
    ``EDGE_DIMS`` and ``INTERACTION_TYPES`` at construction time. ``conv3`` is
    created but not used by ``forward`` at the moment.
    """
    def __init__(self):
        super().__init__()
        # Fixed: the original referenced the undefined name NODE_DIM (NameError on instantiation).
        # The last 3 node columns are sliced off in forward(), hence the -3.
        self.linear1 = torch.nn.Linear(NODE_DIMS-3,hidden)
        self.linear2 = torch.nn.Linear(hidden, len(INTERACTION_TYPES))

        self.conv1 = GATv2Conv(hidden, hidden, edge_dim=EDGE_DIMS, fill_value=SELF_EDGE)
        self.conv2 = GATv2Conv(hidden, hidden, edge_dim=EDGE_DIMS, fill_value=SELF_EDGE)
        self.conv3 = GATv2Conv(hidden, hidden, edge_dim=EDGE_DIMS, fill_value=SELF_EDGE)

    def forward(self, data):
        """Return one row of class logits per node of ``data``."""
        x, edge_index, edge_features = data.x[:,:-3], data.edge_index, data.edge_attr

        x = self.linear1(x)

        h = self.conv1(x, edge_index, edge_features)
        h = F.relu(h)
        # Fixed: conv2 was fed `x`, which threw away conv1's output; chain the layers.
        h = self.conv2(h, edge_index, edge_features)
        h = F.relu(h)

        o = self.linear2(h)
        return o

In [None]:
# Scratch check: pdist returns one distance per row pair (row k of `a` vs row k of `b`).
a = torch.tensor([[1, 2, 3], [25, 0, 0]])
b = torch.tensor([[4, 5, 6], [0, 25, 0]])

pdist(a, b)


In [None]:
# Train the GCN with class-weighted cross entropy, scoring only the DUMMY nodes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))

batch_idx = 0  # running count of optimizer steps across all epochs
for epoch in range(5000):
    avg_loss = []
    print("EPOCH %s" % epoch)

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Rows whose DUMMY column equals 1; only their logits are compared with batch.y.
        # NOTE(review): assumes these rows line up one-to-one, in order, with batch.y -- confirm.
        dummy_indices = (batch.x[:, DUMMY_INDEX] == 1).nonzero(as_tuple=True)
        out = model(batch)
        dummy_node_out = out[dummy_indices]

        loss = loss_function(dummy_node_out, batch.y)
        loss.backward()
        optimizer.step()

        avg_loss.append(loss.item())
        batch_idx += 1

    # Epoch summary, plus predictions vs. labels for the first 15 examples of the last batch.
    print("Average loss:", sum(avg_loss) / len(avg_loss))
    print(dummy_node_out.argmax(dim=1)[:15])
    print(batch.y[:15])
    avg_loss = []