In [None]:
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install pytorch-lightning
!pip install PyTDC
# !pip install torch==1.9
# !pip install torchtext==0.10

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
# import torch_xla.core.xla_model as xm

from torchmetrics import MeanAbsoluteError

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import numpy as np
import pandas as pd

from tdc.multi_pred import DTI

# Mini-batch size used by every DataLoader built later in the notebook.
BATCH_SIZE = 256

# Fetch the DAVIS drug-target interaction dataset through TDC, apply TDC's
# log conversion to the labels (form="binding"), then take TDC's default split.
davis = DTI(name="DAVIS")
davis.convert_to_log(form="binding")
davis_split = davis.get_split()

# Re-number the rows of each split from 0 so they can be addressed by integer
# position downstream.
train_df, valid_df, test_df = (
    davis_split[split_name].reset_index(drop=True)
    for split_name in ("train", "valid", "test")
)

print(f"train: {train_df.shape} valid: {valid_df.shape} test: {test_df.shape}")

In [3]:
def generate_vocab(corpus):
    """Build token <-> index lookup tables from a collection of sequences.

    Index 0 is reserved for the "<PAD>" token. Every other distinct token is
    assigned the next free index in order of first appearance.

    Args:
        corpus: iterable of sequences; iterating a sequence yields its tokens
            (for a string, its characters).

    Returns:
        A tuple ``(stoi, itos)`` where ``stoi`` maps token -> index and
        ``itos`` maps index -> token.
    """
    stoi = {"<PAD>": 0}
    itos = {0: "<PAD>"}

    for sequence in corpus:
        for symbol in sequence:
            if symbol in stoi:
                continue
            # stoi always holds exactly the tokens numbered so far, so its
            # size is the next unused index.
            next_index = len(stoi)
            stoi[symbol] = next_index
            itos[next_index] = symbol

    return stoi, itos

# Character-level vocabularies are derived from the training split only.
drug_corpus = train_df['Drug'].values
target_corpus = train_df['Target'].values

drug_stoi, drug_itos = generate_vocab(drug_corpus)
target_stoi, target_itos = generate_vocab(target_corpus)

# Vocabulary sizes (including the "<PAD>" entry) used to size the embeddings.
drug_vocab_dim, target_vocab_dim = len(drug_stoi), len(target_stoi)

In [4]:
class DTIDataset(Dataset):
    """Drug/target/label triples encoded as fixed-length index sequences.

    Each item is a tuple ``(drug, target, y)`` of float tensors: the drug and
    target sequences are mapped token by token through their vocabularies and
    then right-padded with 0 (the "<PAD>" index) or truncated to a fixed
    length; ``y`` is the label converted to a float tensor.

    Args:
        drug: indexable collection of drug sequences (e.g. a pandas Series of
            strings). Items are fetched with ``drug[idx]``, so a pandas Series
            should carry a 0..n-1 index.
        target: indexable collection of target sequences, same length as ``drug``.
        y: indexable collection of numeric labels, same length as ``drug``.
        drug_stoi: dict mapping drug tokens to integer indices.
        target_stoi: dict mapping target tokens to integer indices.
        drug_max_seq_len: length every encoded drug sequence is padded or
            truncated to (default 85).
        target_max_seq_len: length every encoded target sequence is padded or
            truncated to (default 1200).
    """

    def __init__(self, drug, target, y, drug_stoi, target_stoi,
                 drug_max_seq_len=85, target_max_seq_len=1200):
        self.drug = drug
        self.target = target

        self.drug_max_seq_len = drug_max_seq_len
        self.target_max_seq_len = target_max_seq_len

        self.y = y
        self.drug_stoi = drug_stoi
        self.target_stoi = target_stoi

    def __len__(self):
        """Return the number of (drug, target, y) samples."""
        return len(self.drug)

    @staticmethod
    def _encode(sequence, stoi, max_seq_len):
        """Map ``sequence`` to exactly ``max_seq_len`` vocabulary indices.

        Tokens missing from ``stoi`` are mapped to 0, the padding index, so a
        sequence containing symbols the vocabulary was not built from does not
        raise ``KeyError``.
        """
        # Truncate first: tokens beyond max_seq_len would be discarded anyway.
        indices = [stoi.get(token, 0) for token in sequence[:max_seq_len]]
        indices += [0] * (max_seq_len - len(indices))
        return indices

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as ``(drug, target, y)`` float tensors."""
        drug = self._encode(self.drug[idx], self.drug_stoi, self.drug_max_seq_len)
        target = self._encode(self.target[idx], self.target_stoi, self.target_max_seq_len)
        y = self.y[idx]

        return torch.tensor(drug).float(), torch.tensor(target).float(), torch.tensor(y).float()
    
# Wrap each split in a DTIDataset (all encoded with the training vocabularies)
# and a DataLoader. Only the training loader shuffles; the validation and test
# loaders keep a fixed order, since shuffling evaluation data serves no purpose.
train_dataset = DTIDataset(train_df["Drug"], train_df["Target"], train_df["Y"], drug_stoi, target_stoi)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, shuffle=True)

valid_dataset = DTIDataset(valid_df["Drug"], valid_df["Target"], valid_df["Y"], drug_stoi, target_stoi)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, shuffle=False)

test_dataset = DTIDataset(test_df["Drug"], test_df["Target"], test_df["Y"], drug_stoi, target_stoi)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, shuffle=False)

In [21]:
class Embedding(nn.Module):
    """Embedding lookup followed by dropout.

    Args:
        vocab_dim: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        dropout_rate: dropout probability applied to the looked-up vectors.
    """

    def __init__(self, vocab_dim, embedding_dim, dropout_rate):
        super().__init__()
        self.vocab_dim, self.embedding_dim, self.dropout_rate = vocab_dim, embedding_dim, dropout_rate

        self.embedding = nn.Embedding(num_embeddings=vocab_dim, embedding_dim=embedding_dim)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Cast ``x`` to integer indices, look them up and apply dropout."""
        token_ids = x.long()
        return self.dropout(self.embedding(token_ids))


class Encoder(nn.Module):
    """Sequence encoder: embedding, three 1-D convolutions, global max pool, linear.

    Args:
        vocab_dim: vocabulary size passed to the embedding layer.
        embedding_dim: size of each embedding vector; also the number of input
            channels of the first convolution.
        kernel_size: sequence of three kernel sizes, one per convolution.
        dropout_rate: dropout probability used by the embedding layer.

    ``forward`` takes a 2-D batch of token indices (batch first) and returns a
    ``(batch, 256)`` tensor.
    """

    def __init__(self, vocab_dim, embedding_dim, kernel_size, dropout_rate=0.1):
        super(Encoder, self).__init__()
        self.embedding = Embedding(vocab_dim, embedding_dim, dropout_rate)

        # The first convolution consumes the embedding channels, so its input
        # width must follow embedding_dim instead of being fixed at 32;
        # otherwise any other embedding size fails with a channel mismatch.
        self.conv1 = nn.Conv1d(in_channels=embedding_dim, out_channels=32, kernel_size=kernel_size[0], padding=1)
        self.activation1 = nn.ReLU()

        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=kernel_size[1], padding=1)
        self.activation2 = nn.ReLU()

        self.conv3 = nn.Conv1d(in_channels=64, out_channels=96, kernel_size=kernel_size[2], padding=1)
        self.activation3 = nn.ReLU()

        self.fc = nn.Linear(96, 256)

    def forward(self, x):
        """Encode a batch of index sequences into one 256-dim vector per sample."""
        batch_size = x.shape[0]
        x = self.embedding(x)
        # Conv1d expects channels on axis 1: (batch, seq, emb) -> (batch, emb, seq).
        x = x.moveaxis(1, 2)

        x = self.activation1(self.conv1(x))
        x = self.activation2(self.conv2(x))
        x = self.activation3(self.conv3(x))
        # Global max pooling over the sequence axis leaves one value per channel.
        x = F.adaptive_max_pool1d(x, output_size=1)

        x = x.view(batch_size, -1)

        x = self.fc(x.float())

        return x


class DTI(nn.Module):
    """Two-branch network that predicts one value per (drug, target) pair.

    A drug encoder and a target encoder each produce a 256-dim vector; the two
    vectors are concatenated and passed through three fully connected layers
    (with dropout after the last one) and a final 1-unit output layer.

    NOTE: this class has the same name as ``tdc.multi_pred.DTI`` imported
    earlier in the notebook; once this cell has run, ``DTI`` refers to this
    module rather than the TDC data loader.

    Args:
        drug_vocab_dim: vocabulary size of the drug encoder's embedding.
        target_vocab_dim: vocabulary size of the target encoder's embedding.
        embedding_dim: embedding size used by both encoders.
        dropout_rate: dropout probability for the encoders and the head.
    """

    def __init__(self, drug_vocab_dim, target_vocab_dim, embedding_dim, dropout_rate=0.1):
        super(DTI, self).__init__()
        drug_encoder_kernel_size = [4, 6, 8]
        target_encoder_kernel_size = [4, 8, 12]

        self.drug_encoder = Encoder(drug_vocab_dim, embedding_dim, drug_encoder_kernel_size, dropout_rate=dropout_rate)
        # The target branch must be sized by the target vocabulary: sizing it
        # with drug_vocab_dim gives the wrong embedding table and can raise an
        # index error when the target vocabulary is the larger one.
        self.target_encoder = Encoder(target_vocab_dim, embedding_dim, target_encoder_kernel_size, dropout_rate=dropout_rate)

        self.fc1 = nn.Linear(256 * 2, 1024)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(1024, 1024)
        self.activation2 = nn.ReLU()
        self.fc3 = nn.Linear(1024, 512)
        self.activation3 = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

        self.out = nn.Linear(512, 1)

    def forward(self, drug, target):
        """Return a ``(batch, 1)`` prediction for encoded drug/target batches."""
        drug_encoding = self.drug_encoder(drug)
        target_encoding = self.target_encoder(target)

        # Join the two 256-dim encodings along the feature axis.
        x = torch.cat((drug_encoding, target_encoding), dim=1)

        x = self.activation1(self.fc1(x))
        x = self.activation2(self.fc2(x))
        x = self.dropout(self.activation3(self.fc3(x)))

        out = self.out(x)

        return out

In [22]:
class DeepDTI(pl.LightningModule):
    """Lightning wrapper that trains a drug-target model with an MSE loss.

    Args:
        model: module called as ``model(drug, target)`` that returns the
            predictions for a batch.
        learning_rate: initial learning rate of the Adam optimizer.
    """

    def __init__(self, model, learning_rate):
        super(DeepDTI, self).__init__()
        self.model = model
        self.learning_rate = learning_rate

    def _mse_loss(self, batch):
        """Run the model on a ``(drug, target, y)`` batch and return the MSE loss."""
        drug, target, y = batch
        y_hat = self.model(drug, target)
        # Give the labels the same shape as the predictions. If the model
        # returns (batch, 1) while the labels arrive as (batch,), mse_loss
        # would broadcast the pair to (batch, batch) and average over every
        # prediction/label combination instead of the matching pairs.
        return F.mse_loss(y_hat, y.reshape(y_hat.shape))

    def training_step(self, batch, batch_idx):
        """Compute, log ("train_loss") and return the training loss."""
        loss = self._mse_loss(batch)
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as "valid_loss"."""
        loss = self._mse_loss(batch)
        self.log("valid_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it as "test_loss"."""
        loss = self._mse_loss(batch)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler monitoring "valid_loss"."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)

        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "valid_loss"}

    
def define_callbacks(patience):
    """Create the early-stopping callback used for training.

    Args:
        patience: value forwarded to ``EarlyStopping`` as its ``patience``.

    Returns:
        An ``EarlyStopping`` callback that monitors 'valid_loss'.
    """
    early_stopping = EarlyStopping(monitor='valid_loss', patience=patience)
    return early_stopping

In [None]:
# Build the network, wrap it in the Lightning module and train on one GPU with
# early stopping on the validation loss.
prediction_head = DTI(drug_vocab_dim, target_vocab_dim, embedding_dim=32)
model = DeepDTI(prediction_head, 0.001)
callbacks = define_callbacks(patience=30)
# trainer = pl.Trainer(accelerator="cpu", num_processes=1, max_epochs=1, enable_progress_bar=True)
# Select the device with `accelerator`/`devices`: the older `gpus` argument is
# rejected by current pytorch-lightning releases.
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=100, enable_progress_bar=True, callbacks=callbacks, default_root_dir="drive/MyDrive/DeepDTA_CKPT")
trainer.fit(model, train_dataloader, valid_dataloader)

In [None]:
# Evaluate the model on the test DataLoader; DeepDTI.test_step logs the loss
# under the name "test_loss".
trainer.test(model, test_dataloader)