# Drug - Protein Interaction Prediction

In [1]:
# Importing the libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import scipy.sparse as sp
from sklearn.model_selection import KFold

In [2]:
from sklearn.metrics import accuracy_score

In [3]:
torch.cuda.is_available()

True

In [4]:
# Building the Heterogeneous Network
def hetero_generate(drug_drug, drug_simi, prot_prot, prot_simi, A, w1, w2):
    drug_fused = w1*drug_simi + drug_drug
    prot_fused = w2*prot_simi + prot_prot

    drug_diag = pd.DataFrame(index=drug_fused.index, columns=drug_fused.columns, dtype=float)
    prot_diag = pd.DataFrame(index=prot_fused.index, columns=prot_fused.columns, dtype=float)
    sum_drug = pd.DataFrame(torch.diag(Tensor(drug_fused.sum(axis=1))).numpy())
    sum_prot = pd.DataFrame(torch.diag(Tensor(prot_fused.sum(axis=1))).numpy())
 
    drug_inv = sum_drug.apply(lambda x: 1 / np.sqrt(x))
    prot_inv = sum_prot.apply(lambda x: 1 / np.sqrt(x))
    drug_inv.replace([np.inf, -np.inf], 0, inplace=True)
    prot_inv.replace([np.inf, -np.inf], 0, inplace=True)

    drug_fused.set_axis([i for i in range(708)], axis=1, inplace=True)
    U = (drug_inv.dot(drug_fused)).dot(drug_inv)
    prot_fused.set_axis([i for i in range(1512)], axis=1, inplace=True)
    P = (prot_inv.dot(prot_fused)).dot(prot_inv)

    A.set_axis([i for i in range(708)], axis=0, inplace=True)
    A.set_axis([i for i in range(708,2220)], axis=1, inplace=True)

    drug_emb = pd.concat([U, A], axis=1)

    At = A.transpose()
    
    At.set_axis([i for i in range(708,2220)], axis=0, inplace=True)
    At.set_axis([i for i in range(708)], axis=1, inplace=True)
    
    P.set_axis([i for i in range(708,2220)], axis=0, inplace=True)
    P.set_axis([i for i in range(708,2220)], axis=1, inplace=True)
    
    prot_emb = pd.concat([At,P], axis=1)
    het_net = pd.concat([drug_emb,prot_emb], axis=0)

    drug_drug.set_axis([i for i in range(708)], axis=1, inplace=True)
    prot_prot.set_axis([i for i in range(708,2220)], axis=0, inplace=True)
    prot_prot.set_axis([i for i in range(708,2220)], axis=1, inplace=True)

    adj1 = pd.concat([drug_drug, A], axis=1)
    adj2 = pd.concat([At, prot_prot], axis=1)
    adj = pd.concat([adj1, adj2], axis=0)

    adj_coo = sp.coo_matrix(adj)
    row_indices = adj_coo.row
    col_indices = adj_coo.col
    edge_index = torch.tensor([row_indices, col_indices], dtype=torch.long)
    het_ten = torch.tensor(het_net.values, dtype=torch.float)

    return het_ten, edge_index

In [5]:
# Defining the model architecture
class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, weight_param,  dropout=0.6):
        super().__init__()
        self.weight_param = torch.nn.Parameter(torch.Tensor(weight_param,weight_param))
        torch.nn.init.xavier_uniform_(self.weight_param)
        self.conv1 = GCNConv(in_channels, 1024)
        self.conv2 = GCNConv(1024 , 512)
        self.conv3 = GCNConv(512, 256)
        self.conv4 = GCNConv(256 ,out_channels)
        self.act = torch.nn.ReLU()
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.dropout(x, p=0.6)
        x = self.act(x)
        x = self.conv2(x, edge_index)   
        x = F.dropout(x, p=0.6)
        x = self.act(x)
        x = self.conv3(x, edge_index)
        x = F.dropout(x, p=0.6)
        x = self.act(x)
        x = self.conv4(x, edge_index)
        x = x.sigmoid()
        U = x[:708,:]
        V = x[708:,:]
        inter = torch.matmul(torch.matmul(U,self.weight_param),V.transpose(0,1))
        return inter.sigmoid()

In [6]:
# Defining the loss function
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, pos_weight=None, num_drugs=708, num_prots=1512):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.pos_weight = pos_weight
        self.num_drugs = num_drugs
        self.num_prots = num_prots

    def forward(self, predictions, targets):
        bce_loss = nn.BCELoss(weight=self.pos_weight, reduction='sum')
        loss = (1/(self.num_drugs*self.num_prots))*bce_loss(predictions, targets)
        return loss

In [7]:
# Training the model
def train(model, hetero_net, edge_index, val_data, optimizer, criterion, gpu, epoch = 4000):
    if gpu=='gpu':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"
    model.to(device)
    hetero_net = hetero_net.to(device)
    edge_index = edge_index.to(device)  
    val_data = val_data.to(device)
    for e in range(1,epoch+1):
        model.train()
        optimizer.zero_grad()
        out = model(hetero_net, edge_index)
        loss = criterion(out, val_data)
        loss.backward()
        optimizer.step()
        if e % 100 == 0:
            print('Epoch {}, Loss {}'.format(e,loss))
    return model, out

In [8]:
At = pd.read_csv('dti/mat_prot_drug.csv')
A = At.transpose()

In [9]:
drug_simi = pd.read_csv('dti/drug_simi.csv')
drug_simi = drug_simi.drop("Column709",axis = 1)

In [10]:
prot_simi = pd.read_csv('dti/prot_simi.csv')

In [11]:
drug_drug = pd.read_csv('dti/mat_drug_drug.csv')
prot_prot = pd.read_csv('dti/mat_prot_prot.csv')

In [12]:
w1 = 5
w2 = 5

In [None]:
het_ten, edge_index = hetero_generate(drug_drug, drug_simi, prot_prot, prot_simi, A, w1, w2)

In [14]:
data = Data(x=het_ten, edge_index=edge_index)

In [15]:
weights = pd.DataFrame(np.ones(shape = (708,1512)))
pos_mask = A == 1
pos_sam = A.sum().sum()
neg_sam = (708*1512) - pos_sam
weights[pos_mask] = pos_sam/neg_sam
wei_ten = Tensor(weights.values).to('cuda')

In [16]:
A_ten = torch.tensor(A.values, dtype=torch.float32)

In [17]:
model = GCN(2220, 64, 64, 0.6)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
crit = WeightedCrossEntropyLoss(wei_ten)
model, out = train(model, data.x, data.edge_index, A_ten, optimizer, crit, 'gpu', 4000) 

Epoch 100, Loss 0.00938374176621437
Epoch 200, Loss 0.00849129632115364
Epoch 300, Loss 0.007321196608245373
Epoch 400, Loss 0.006419534794986248
Epoch 500, Loss 0.005697275046259165
Epoch 600, Loss 0.005144158378243446
Epoch 700, Loss 0.004706823732703924
Epoch 800, Loss 0.004257701802998781
Epoch 900, Loss 0.0039937859401106834
Epoch 1000, Loss 0.0037149330601096153
Epoch 1100, Loss 0.00352885271422565
Epoch 1200, Loss 0.003233392257243395
Epoch 1300, Loss 0.0031085724476724863
Epoch 1400, Loss 0.0029646665789186954
Epoch 1500, Loss 0.0028533563017845154
Epoch 1600, Loss 0.002734668320044875
Epoch 1700, Loss 0.002532227896153927
Epoch 1800, Loss 0.002555325161665678
Epoch 1900, Loss 0.0024026907049119473
Epoch 2000, Loss 0.002183341421186924
Epoch 2100, Loss 0.002247321419417858
Epoch 2200, Loss 0.002095705596730113
Epoch 2300, Loss 0.0020342175848782063
Epoch 2400, Loss 0.0020327477250248194
Epoch 2500, Loss 0.0018765494460240006
Epoch 2600, Loss 0.0018631231505423784
Epoch 2700, Lo