In [1]:
import torch
import torch.optim as optim
import torch_geometric
from torch.nn.functional import relu, sigmoid, softmax, mse_loss
from torch.nn import Linear, Module, Dropout, MSELoss, CrossEntropyLoss, BatchNorm1d

from tqdm import tqdm

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GATv2Conv, SchNet

import pandas as pd
import numpy as np

import os
import pickle
import gzip

# Presumably an XLA/JAX memory pre-allocation switch; nothing in the visible
# cells uses XLA -- TODO confirm this setting is still needed.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Run on the selected CUDA device when CUDA is available, otherwise on the CPU.
gpu_index = 0
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(gpu_index))
else:
    device = torch.device("cpu")

In [2]:
class MultiHeadAttention(Module):
    """Multi-head scaled dot-product attention that returns one vector per batch item.

    Inputs may be ``(batch, hidden_dim)`` (treated as a length-1 sequence) or
    ``(batch, seq_len, hidden_dim)``. Only the attended vector of the first
    query position is kept, so the output is always ``(batch, hidden_dim)``.

    Args:
        hidden_dim: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()

        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.query_linear = Linear(hidden_dim, hidden_dim)
        self.key_linear = Linear(hidden_dim, hidden_dim)
        self.value_linear = Linear(hidden_dim, hidden_dim)

        self.output_linear = Linear(hidden_dim, hidden_dim)
        self.dropout = Dropout(dropout)

        # A plain Python float works on every device, so the module no longer
        # depends on a notebook-level `device` variable and needs no `.to()`.
        self.scale = float(self.head_dim) ** 0.5

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, [seq,] hidden_dim)`` to ``(batch, heads, seq, head_dim)``.

        The feature axis is split into heads first and the head axis is then
        moved in front of the sequence axis, so every head sees a contiguous
        feature slice of each position.
        """
        per_head = projected.view(batch_size, -1, self.num_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors of shape ``(batch, hidden_dim)`` or
                ``(batch, seq_len, hidden_dim)``.
            mask: Optional tensor broadcastable to the
                ``(batch, heads, query_len, key_len)`` score tensor; positions
                where ``mask == 0`` are filled with ``-1e10`` before the softmax.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.query_linear(query), batch_size)
        K = self._split_heads(self.key_linear(key), batch_size)
        V = self._split_heads(self.value_linear(value), batch_size)

        # Scaled dot-product scores: (batch, heads, query_len, key_len).
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = self.dropout(softmax(energy, dim=-1))

        weighted_matrix = torch.matmul(attention, V)

        # Merge the heads back: (batch, query_len, hidden_dim).
        weighted_matrix = weighted_matrix.permute(0, 2, 1, 3).contiguous()
        weighted_matrix = weighted_matrix.view(batch_size, -1, self.hidden_dim)
        # Keep only the first query position (no mean/max pooling over positions).
        weighted_matrix = weighted_matrix[:, 0, :]

        return self.output_linear(weighted_matrix)

In [3]:
class DTIPredictor(Module):
    """Regress one score per (drug, protein) pair of batched graphs.

    Each side is encoded by its own ``SchNet`` built with library defaults, the
    drug encoding attends over the protein encoding, and a linear head maps the
    concatenation ``[drug | attended | protein]`` to a single value.

    Args:
        hidden_dim: Width expected from each encoder output and used by the
            attention block (assumes the encoders emit ``hidden_dim`` features
            per graph -- TODO confirm against the SchNet configuration).
        num_heads: Number of attention heads.
        dropout: Dropout probability inside the attention block.
    """

    def __init__(self, hidden_dim=1, num_heads=1, dropout=0.2):
        super(DTIPredictor, self).__init__()
        # Separate encoders so drug and protein weights are not shared.
        self.drug_encoder = SchNet()
        self.protein_encoder = SchNet()
        self.attention = MultiHeadAttention(hidden_dim, num_heads, dropout)
        # The head consumes three hidden_dim-wide pieces concatenated together.
        self.fc_output = Linear(3 * hidden_dim, 1)

    @staticmethod
    def _encode(encoder, graphs):
        """Run one encoder on a batch exposing ``atom_type``, ``pos`` and ``batch``."""
        atom_types = graphs.atom_type.int()
        coordinates = graphs.pos.float()
        return encoder(atom_types, coordinates, graphs.batch)

    def forward(self, drug_data, protein_data):
        """Return a 1-D tensor holding one prediction per row of the encodings."""
        drug_repr = self._encode(self.drug_encoder, drug_data)
        protein_repr = self._encode(self.protein_encoder, protein_data)

        # Drug is the query; protein supplies both keys and values.
        attended = self.attention(drug_repr, protein_repr, protein_repr)

        combined = torch.cat([drug_repr, attended, protein_repr], dim=1)
        return self.fc_output(combined).squeeze(1)

In [4]:
# Read the three KIBA split tables; the first CSV column becomes the index.
train, val, test = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('kiba/train.csv', 'kiba/val.csv', 'kiba/test.csv')
)

In [5]:
# Report (rows, columns) of every split table.
for split_label, split_frame in (('Train dim:', train), ('val dim:', val), ('test dim:', test)):
    print(split_label, split_frame.shape)

Train dim: (3449, 3)
val dim: (494, 3)
test dim: (973, 3)


In [6]:
# Load the drug lookup that get_drug_dataloader indexes by drug identifier.
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted archives.
with gzip.open('drug.pkl.gz', 'rb') as drug_file:
    drug = pickle.load(drug_file)

def get_drug_dataloader(drugs, batch_size=100):
    """Build a DataLoader over the drug entries named by ``drugs``.

    Args:
        drugs: Iterable of keys into the module-level ``drug`` mapping; the
            dataset keeps this order and any repeats.
        batch_size: Number of items per batch.

    Returns:
        A ``DataLoader`` over the looked-up entries.
    """
    selected = []
    for drug_id in drugs:
        selected.append(drug[drug_id])
    return DataLoader(selected, batch_size=batch_size)

def get_protein_dataloader(proteins, batch_size=100):
    """Build a DataLoader over the protein graphs named by ``proteins``.

    Each identifier is read from ``protein_graphs/<id>.pt``. An identifier that
    occurs several times is loaded from disk only once per call and the same
    loaded object is reused for every occurrence, which avoids repeated file
    reads and duplicate copies in memory.

    Args:
        proteins: Iterable of hashable protein identifiers; the dataset keeps
            this order and any repeats.
        batch_size: Number of items per batch.

    Returns:
        A ``DataLoader`` over the loaded graphs.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted graphs.
    loaded_graphs = {}
    dataset = []
    for protein_id in proteins:
        if protein_id not in loaded_graphs:
            loaded_graphs[protein_id] = torch.load('protein_graphs/{}.pt'.format(protein_id))
        dataset.append(loaded_graphs[protein_id])
    return DataLoader(dataset, batch_size=batch_size)

In [7]:
batch_size = 10

# Drug, protein and label loaders are built from the same tables in the same
# row order, so batch i of each loader refers to the same table rows.
drug_train_loader, drug_val_loader, drug_test_loader = (
    get_drug_dataloader(split['Drug'], batch_size) for split in (train, val, test)
)

protein_train_loader, protein_val_loader, protein_test_loader = (
    get_protein_dataloader(split['Target_ID'], batch_size) for split in (train, val, test)
)

# Regression targets, batched with the same batch size as the graph loaders.
train_y, val_y, test_y = (
    DataLoader(torch.Tensor(split['Y'].values).float(), batch_size=batch_size)
    for split in (train, val, test)
)



In [None]:
# Train the DTI regressor for 10 epochs with MSE loss and record the mean
# per-batch training and validation loss of every epoch.
# NOTE(review): the recorded run below reports losses around 1e25, so the model
# is not fitting; presumably the encoder outputs are un-normalised and/or the
# learning rate is too high -- investigate before trusting any result.
model = DTIPredictor().to(device)
criterion = MSELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

train_losses = []
train_accs = []
val_losses = []
val_accs = []

# zip() stops at the shortest loader without warning, so make sure the three
# parallel loaders of each split really have the same number of batches.
assert len(drug_train_loader) == len(protein_train_loader) == len(train_y)
assert len(drug_val_loader) == len(protein_val_loader) == len(val_y)

for epoch in tqdm(range(10)):

    model.train()
    total_loss = 0
    # The loop variables are deliberately not called `drug`: that name is the
    # lookup table get_drug_dataloader reads, and overwriting it would break
    # any later re-run of the loader cells.
    for drug_batch, protein_batch, target_batch in zip(drug_train_loader, protein_train_loader, train_y):
        drug_batch = drug_batch.to(device)
        protein_batch = protein_batch.to(device)
        target_batch = target_batch.to(device)

        optimizer.zero_grad()

        output = model(drug_batch, protein_batch)
        loss = criterion(output, target_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_loss = total_loss / len(drug_train_loader)
    train_losses.append(average_loss)

    model.eval()
    with torch.no_grad():
        val_loss = 0
        for drug_batch, protein_batch, target_batch in zip(drug_val_loader, protein_val_loader, val_y):
            drug_batch = drug_batch.to(device)
            protein_batch = protein_batch.to(device)
            target_batch = target_batch.to(device)

            output = model(drug_batch, protein_batch)
            loss = criterion(output, target_batch)
            val_loss += loss.item()

        # Exactly one entry per epoch (the running sum is no longer appended
        # after every batch), so val_losses lines up with train_losses.
        average_val_loss = val_loss / len(drug_val_loader)
        val_losses.append(average_val_loss)
    print(f'Epoch: {epoch+1}, Train loss: {average_loss:.4f}, Validation Loss: {average_val_loss:.4f}')

 10%|█         | 1/10 [03:15<29:18, 195.43s/it]

Epoch: 1, Train loss: 9704670802872118097739776.0000, Validation Loss: 11995704253833604329635840.0000


 20%|██        | 2/10 [06:23<25:29, 191.18s/it]

Epoch: 2, Train loss: 13426422195643830272786432.0000, Validation Loss: 15818842126724574088265728.0000


 30%|███       | 3/10 [09:31<22:08, 189.85s/it]

Epoch: 3, Train loss: 15072883498246984442052608.0000, Validation Loss: 16410428067776149590114304.0000


 40%|████      | 4/10 [12:42<19:01, 190.32s/it]

Epoch: 4, Train loss: 15424129192380555203706880.0000, Validation Loss: 16313776548198604091686912.0000


 50%|█████     | 5/10 [15:54<15:53, 190.74s/it]

Epoch: 5, Train loss: 15333454150872954572898304.0000, Validation Loss: 16241786953471018382917632.0000


 60%|██████    | 6/10 [19:03<12:41, 190.31s/it]

Epoch: 6, Train loss: 15265825856519578104889344.0000, Validation Loss: 16173818864243651387588608.0000


 70%|███████   | 7/10 [22:14<09:30, 190.33s/it]

Epoch: 7, Train loss: 15200319444435324029108224.0000, Validation Loss: 16106919441017334484959232.0000


 80%|████████  | 8/10 [25:24<06:20, 190.25s/it]

Epoch: 8, Train loss: 15137775209762522715389952.0000, Validation Loss: 16044476970214310654509056.0000


In [None]:
# Run the trained model over the test split, summing the per-batch MSE in
# `test_loss` and collecting every prediction into one NumPy array `pred`.
model.eval()
pred = []
with torch.no_grad():
    test_loss = 0
    # Loop variables are not called `drug` so the drug lookup table read by
    # get_drug_dataloader is not overwritten.
    for drug_batch, protein_batch, target_batch in zip(drug_test_loader, protein_test_loader, test_y):
        drug_batch = drug_batch.to(device)
        protein_batch = protein_batch.to(device)
        target_batch = target_batch.to(device)

        output = model(drug_batch, protein_batch)
        loss = criterion(output, target_batch)
        test_loss += loss.item()
        # No detach() needed: nothing tracks gradients inside torch.no_grad().
        pred.append(output.cpu().numpy())

pred = np.concatenate(pred)

In [None]:
# Mean squared error of all collected predictions against the `Y` column of the
# test table; kept as the last expression so the notebook displays the value.
test_targets = torch.Tensor(np.array(test['Y']))
mse_loss(torch.Tensor(pred), test_targets)