In [2]:
import torch
import torch.optim as optim
import torch_geometric
import torch_geometric.transforms as T
from torch.nn.functional import relu, sigmoid
from torch.nn import Linear, Module, Dropout, MSELoss, CrossEntropyLoss, BatchNorm1d

from torch_geometric.nn import GCNConv, GATConv, GraphNorm, TransformerConv
from torch_geometric.data import Data
from torch_sparse import SparseTensor

import pandas as pd
import numpy as np
import random
import optuna

# Reproducibility: every RNG used by the notebook is pinned to the same seed.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# cuDNN: deterministic mode on, benchmark mode off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import os

# NOTE(review): this flag looks like an XLA/JAX client setting; nothing in the
# visible cells reads it -- confirm it is still needed.
os.environ.update({'XLA_PYTHON_CLIENT_PREALLOCATE': 'false'})

# Compute device: CUDA device 0 when CUDA is available, otherwise the CPU.
device = torch.device("cuda:{}".format(0) if torch.cuda.is_available() else "cpu")

from tqdm import tqdm

In [4]:
class MultiHeadAttention(Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each projected with their own linear layer,
    split into ``num_heads`` heads of size ``hidden_dim // num_heads``,
    attended per head, merged again and passed through an output projection.

    Args:
        hidden_dim: Size of the last dimension of query/key/value. Must be
            divisible by ``num_heads`` (checked with ``assert``).
        num_heads: Number of attention heads.
        dropout: Dropout probability, applied to the attention weights and to
            the projected output.
    """

    def __init__(self, hidden_dim, num_heads, dropout):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads

        self.query_linear = Linear(hidden_dim, hidden_dim)
        self.key_linear = Linear(hidden_dim, hidden_dim)
        self.value_linear = Linear(hidden_dim, hidden_dim)

        self.output_linear = Linear(hidden_dim, hidden_dim)
        self.dropout = Dropout(dropout)

        # Registered as a buffer instead of calling .cuda(): the module can be
        # built on CPU-only machines and the scale follows .to(device).
        # persistent=False keeps it out of state_dict, as before.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(self.head_dim)])),
            persistent=False,
        )

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, hidden_dim) -> (batch, num_heads, seq, head_dim)."""
        # Split the feature axis into heads first, THEN move the head axis
        # forward. A direct view to (batch, heads, seq, head_dim) would mix
        # sequence positions and heads whenever seq > 1.
        return projected.view(
            batch_size, -1, self.num_heads, self.head_dim
        ).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch, query_len, hidden_dim); a 2-D
                (batch, hidden_dim) tensor is treated as query_len == 1.
            key: Tensor of shape (batch, key_len, hidden_dim), or 2-D as above.
            value: Tensor with the same shape as ``key``.
            mask: Optional tensor broadcastable to
                (batch, num_heads, query_len, key_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, query_len, hidden_dim).
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.query_linear(query), batch_size)
        K = self._split_heads(self.key_linear(key), batch_size)
        V = self._split_heads(self.value_linear(value), batch_size)

        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # Large negative score -> (numerically) zero weight after softmax.
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = self.dropout(torch.softmax(energy, dim=-1))

        weighted_matrix = torch.matmul(attention, V)

        # (batch, heads, query_len, head_dim) -> (batch, query_len, hidden_dim)
        weighted_matrix = weighted_matrix.permute(0, 2, 1, 3).contiguous()
        weighted_matrix = weighted_matrix.view(batch_size, -1, self.hidden_dim)

        output = self.dropout(self.output_linear(weighted_matrix))

        return output

In [6]:
class DrugEncoder(Module):
    """Two-layer GCN encoder that reduces a drug graph to one vector per graph.

    Each GCN layer is followed by ReLU and GraphNorm; node features are then
    mean-pooled per graph and projected to ``output_dim``.

    Args:
        num_features_xd: Number of input features per node.
        dim: Hidden width of both GCN layers.
        output_dim: Size of the returned graph embedding.
    """

    def __init__(self, num_features_xd, dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(num_features_xd, dim)
        self.conv2 = GCNConv(dim, dim)
        self.fc1 = Linear(dim, output_dim)
        # torch.nn.ReLU: the bare name ReLU is not imported in this notebook.
        self.relu = torch.nn.ReLU()
        # GraphNorm requires the channel count. One instance per layer so the
        # two layers do not share learnable normalisation parameters.
        self.norm1 = GraphNorm(dim)
        self.norm2 = GraphNorm(dim)

    def forward(self, data):
        """Encode a graph or a mini-batch of graphs.

        Args:
            data: Object exposing ``x`` and ``edge_index``, and optionally
                ``batch`` (node -> graph assignment vector).

        Returns:
            Tensor with one row of size ``output_dim`` per graph.
        """
        x, edge_index = data.x, data.edge_index

        # Without a batch vector every node belongs to a single graph.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        x = self.norm1(self.relu(self.conv1(x, edge_index)), batch)
        x = self.norm2(self.relu(self.conv2(x, edge_index)), batch)

        # Mean over the nodes of each graph.
        x = torch_geometric.nn.global_mean_pool(x, batch)
        x = self.fc1(x)
        return x

In [7]:
class ProteinEncoder(Module):
    """Two-layer GCN encoder that reduces a protein graph to one vector per graph.

    Each GCN layer is followed by ReLU and GraphNorm; node features are then
    mean-pooled per graph and projected to ``output_dim``.

    Args:
        num_features_xd: Number of input features per node.
        dim: Hidden width of both GCN layers.
        output_dim: Size of the returned graph embedding.
    """

    def __init__(self, num_features_xd, dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(num_features_xd, dim)
        self.conv2 = GCNConv(dim, dim)
        self.fc1 = Linear(dim, output_dim)
        # torch.nn.ReLU: the bare name ReLU is not imported in this notebook.
        self.relu = torch.nn.ReLU()
        # GraphNorm requires the channel count. One instance per layer so the
        # two layers do not share learnable normalisation parameters.
        self.norm1 = GraphNorm(dim)
        self.norm2 = GraphNorm(dim)

    def forward(self, data):
        """Encode a graph or a mini-batch of graphs.

        Args:
            data: Object exposing ``x`` and ``edge_index``, and optionally
                ``batch`` (node -> graph assignment vector).

        Returns:
            Tensor with one row of size ``output_dim`` per graph.
        """
        x, edge_index = data.x, data.edge_index

        # Without a batch vector every node belongs to a single graph.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        x = self.norm1(self.relu(self.conv1(x, edge_index)), batch)
        x = self.norm2(self.relu(self.conv2(x, edge_index)), batch)

        # Mean over the nodes of each graph.
        x = torch_geometric.nn.global_mean_pool(x, batch)
        x = self.fc1(x)
        return x

In [9]:
class DTIPredictor(Module):
    """Predicts a drug-target interaction probability from two encoded graphs.

    The drug embedding attends over the protein embedding; the drug embedding,
    the attention output and the protein embedding are concatenated, passed
    through a hidden layer and reduced to a single sigmoid output.

    Args:
        drug_encoder: Module mapping ``drug_data`` to a (batch, hidden_dim)
            tensor.
        protein_encoder: Module mapping ``protein_data`` to a
            (batch, hidden_dim) tensor.
        hidden_dim: Embedding size produced by both encoders.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability used inside the attention block.
            Defaults to 0.0 (disabled).
    """

    def __init__(self, drug_encoder, protein_encoder, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        # Use the encoders supplied by the caller rather than building new,
        # argument-less ones.
        self.drug_encoder = drug_encoder
        self.protein_encoder = protein_encoder
        self.attention = MultiHeadAttention(hidden_dim, num_heads, dropout)

        # forward() concatenates three hidden_dim-sized vectors.
        self.fc1_combined = Linear(hidden_dim * 3, hidden_dim)
        self.fc_output = Linear(hidden_dim, 1)

    def forward(self, drug_data, protein_data):
        """Return interaction probabilities of shape (batch, 1)."""
        x_drug = self.drug_encoder(drug_data)
        x_protein = self.protein_encoder(protein_data)

        # The attention block works on (batch, seq, hidden_dim); each
        # embedding is presented as a length-1 sequence and squeezed back.
        # NOTE(review): with a single key the softmax weight is always 1, so
        # this reduces to a projection of the protein embedding -- confirm
        # whether per-node features were meant to be attended over instead.
        protein_seq = x_protein.unsqueeze(1)
        attention_output = self.attention(
            x_drug.unsqueeze(1), protein_seq, protein_seq
        ).squeeze(1)

        feature = torch.cat((x_drug, attention_output, x_protein), dim=1)
        hidden = torch.relu(self.fc1_combined(feature))
        prediction = self.fc_output(hidden)

        return torch.sigmoid(prediction)