In [1]:
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem.QED import qed
from tdc.multi_pred import DTI

from rdkit import RDLogger 
RDLogger.DisableLog('rdApp.*')

import warnings
warnings.filterwarnings(action='ignore')

# davis = DTI(name="Davis")
# davis.convert_to_log(form='binding')
# davis_split = davis.get_split()

# Load the KIBA dataset through TDC's DTI loader and ask it for its split.
kiba = DTI(name="KIBA")
kiba_split = kiba.get_split()

# Pull the three partitions out of the split mapping in one go.
train_df, valid_df, test_df = (
    kiba_split[partition] for partition in ("train", "valid", "test")
)

Found local copy...
Loading...
Done!


In [3]:
from transformers import BertModel, BertTokenizer
from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer

# Fast tokenizer for molecule strings, restored from a serialized tokenizer file.
# The special tokens are registered explicitly, and model_max_length=128 matches the
# max_length that DTIDataset.molecule_encode passes when tokenizing a molecule.
molecule_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="data/drug/tokenizer_model/vocab.json",
    pad_token="[PAD]",
    mask_token="[MASK]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    unk_token="[UNK]",
    model_max_length=128
)
# Load the pretrained molecule BERT from a local directory only (no hub lookup).
# NOTE(review): the load log reports that pooler.dense.weight / pooler.dense.bias are
# missing from the checkpoint and newly initialized, so anything that reads
# `pooler_output` depends on weights that still have to be trained downstream.
molecule_bert = BertModel.from_pretrained("weights/molecule_bert_pretrained-masking_rate_30", local_files_only=True)

Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert_pretrained-masking_rate_30 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Symbol -> id table for protein sequences. The id of a symbol is its position in
# the alphabet string below; note the tail is "...W Y Z X * -" (not alphabetical),
# which is what puts Y at 23, Z at 24 and X at 25.
fasta_stoi = {
    symbol: index for index, symbol in enumerate("ABCDEFGHIJKLMNOPQRSTUVWYZX*-")
}

# Inverse table (id -> symbol), derived from fasta_stoi so the two cannot drift apart.
fasta_itos = {index: symbol for symbol, index in fasta_stoi.items()}

def encode(data, stoi):
    """Translate each element of ``data`` into its id by looking it up in ``stoi``.

    Returns a new list with one id per element, in the same order. A lookup
    failure in ``stoi`` propagates to the caller.
    """
    ids = []
    for symbol in data:
        ids.append(stoi[symbol])
    return ids


def decode(data, itos):
    """Translate each id in ``data`` back to its symbol by looking it up in ``itos``.

    Returns a new list with one symbol per id, in the same order. A lookup
    failure in ``itos`` propagates to the caller.
    """
    return list(map(itos.__getitem__, data))

# Sanity check: encoding must yield exactly one id per residue, so both printed
# lengths should agree.
sample_fasta = "MTEITAAMVKELRESTGAGMMDCKNALSETNGDFDKAVQLLREKGLGKAAKKADRLAAEGLVSVKVSDDFTIAAMRPSYLSYEDLDMTFVENEYKALVAELEKENEERRRLKDPNKPEHKIPQFASRKQLSDAILKEAEEKIKEELKAQGKPEKIWDNIIPGKMNSFIADNSQLDSKLTLMGQFYVMDDKKTVEQVIAEKEKEFGGKIKIVEFICFEVGEGLEKKTEDFAAEVAAQL"
print(len(sample_fasta), len(encode(sample_fasta, fasta_stoi)), sep="\n")

237
237


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler


class DTIDataset(Dataset):
    """Map-style dataset yielding (tokenized molecule, encoded protein, target) triples.

    Args:
        data: DataFrame with the columns "Drug", "Target" and "Y".
        molecule_tokenizer: callable applied to the space-separated molecule string;
            it is called with ``max_length=128, truncation=True``.
        fasta_stoi: mapping from protein sequence symbol to integer id.
    """

    def __init__(self, data, molecule_tokenizer, fasta_stoi):
        self.data = data

        self.molecule_tokenizer = molecule_tokenizer
        self.fasta_stoi = fasta_stoi

    def molecule_encode(self, molecule_sequence):
        """Tokenize a molecule string, truncated to at most 128 tokens.

        The characters of the string are joined with spaces before being handed
        to the tokenizer, so every character is presented as its own word.
        """
        return self.molecule_tokenizer(
            " ".join(molecule_sequence), max_length=128, truncation=True
        )

    def protein_encode(self, protein_sequence):
        """Convert a protein string into a 1-D ``torch.long`` tensor of symbol ids.

        A symbol that is missing from ``fasta_stoi`` raises ``KeyError``.
        """
        return torch.tensor([self.fasta_stoi[s] for s in protein_sequence]).long()

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at *position* ``idx``.

        Samplers hand out positions in ``range(len(self))``. Rows are therefore
        selected positionally with ``.iloc``; the previous label-based ``.loc``
        lookup only worked when the DataFrame carried a default 0..n-1 index and
        would raise (or pick another row) after any filtering or shuffling that
        left the index non-contiguous.
        """
        molecule_sequence = self.molecule_encode(self.data["Drug"].iloc[idx])
        protein_sequence = self.protein_encode(self.data["Target"].iloc[idx])
        y = torch.tensor(self.data["Y"].iloc[idx]).float()

        return molecule_sequence, protein_sequence, y




def collate_batch(batch):
    """Merge a list of (molecule encoding, protein id tensor, target) samples.

    Protein tensors are cut to their first 2048 positions and right-padded to a
    common length; molecule encodings are padded by the module-level
    ``molecule_tokenizer`` and returned as PyTorch tensors.

    Returns:
        (padded molecule encodings, padded protein ids [batch first], float targets)
    """
    molecules, proteins, targets = [], [], []

    for drug_tokens, residue_ids, target in batch:
        molecules.append(drug_tokens)
        # Slicing is a no-op for sequences that are already within the limit.
        proteins.append(residue_ids[:2048])
        targets.append(target)

    padded_molecules = molecule_tokenizer.pad(molecules, return_tensors="pt")
    # NOTE(review): the pad value 0 is also the id assigned to "A" in fasta_stoi,
    # so padded positions are indistinguishable from that residue downstream.
    padded_proteins = pad_sequence(proteins, batch_first=True, padding_value=0)
    target_tensor = torch.tensor(targets).float()

    return padded_molecules, padded_proteins, target_tensor


def _make_loader(dataset, *, drop_last, sampler=None):
    """Build a DataLoader with the settings shared by all three splits."""
    return DataLoader(
        dataset,
        batch_size=128,
        shuffle=False,
        sampler=sampler,
        num_workers=16,
        pin_memory=True,
        prefetch_factor=10,
        drop_last=drop_last,
        collate_fn=collate_batch,
    )


train_dataset = DTIDataset(train_df, molecule_tokenizer, fasta_stoi)
# Each training "epoch" draws 12800 rows with replacement instead of a full pass.
# NOTE(review): the sampler is built over train_df rather than train_dataset; the
# two have the same length (DTIDataset.__len__ is len(data)) — confirm intended.
train_sampler = RandomSampler(train_df, replacement=True, num_samples=12800)
valid_dataset = DTIDataset(valid_df, molecule_tokenizer, fasta_stoi)
test_dataset = DTIDataset(test_df, molecule_tokenizer, fasta_stoi)

# Only the training loader uses the sampler and drops the trailing partial batch.
train_dataloader = _make_loader(train_dataset, drop_last=True, sampler=train_sampler)
valid_dataloader = _make_loader(valid_dataset, drop_last=False)
test_dataloader = _make_loader(test_dataset, drop_last=False)


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProteinCNNEncoder(nn.Module):
    """1-D CNN that turns a batch of protein id sequences into fixed-size vectors.

    Three blocks of two convolutions each (GELU + dropout after every
    convolution), each block followed by a max-pool that halves the sequence
    length, and finally a global max over the remaining positions.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding width, i.e. input channels of the first conv.
        hidden_dims: channel widths of the three conv blocks.
        dropout: dropout probability applied after every activation.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dims=[1024, 512, 256], dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Attribute names are kept as-is so existing checkpoints keep loading.
        self.conv_block_1_layer_1 = nn.Conv1d(embedding_dim, hidden_dims[0], kernel_size=5, padding=2)
        self.conv_block_1_layer_2 = nn.Conv1d(hidden_dims[0], hidden_dims[0], kernel_size=5, padding=2)

        self.conv_block_2_layer_1 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=3, padding=1)
        self.conv_block_2_layer_2 = nn.Conv1d(hidden_dims[1], hidden_dims[1], kernel_size=3, padding=1)

        self.conv_block_3_layer_1 = nn.Conv1d(hidden_dims[1], hidden_dims[2], kernel_size=3, padding=1)
        self.conv_block_3_layer_2 = nn.Conv1d(hidden_dims[2], hidden_dims[2], kernel_size=3, padding=1)

    def _conv_gelu_dropout(self, conv, x):
        """Apply ``conv`` -> GELU -> dropout.

        ``training=self.training`` is passed explicitly: ``F.dropout`` defaults to
        ``training=True``, which previously kept dropout switched on in eval mode
        and made validation/prediction outputs random.
        """
        return F.dropout(F.gelu(conv(x)), self.dropout, training=self.training)

    def forward(self, x):
        """Encode integer id sequences.

        Args:
            x: integer tensor of token ids, (batch, seq_len).

        Returns:
            Tensor of shape (batch, hidden_dims[-1]).

        Note: the length is halved three times, so very short inputs (fewer than
        8 positions) leave nothing to pool in the last block.
        """
        x = self.embedding(x)
        # Conv1d expects channels before length: (batch, seq, emb) -> (batch, emb, seq).
        x = x.moveaxis(1, 2)

        x = self._conv_gelu_dropout(self.conv_block_1_layer_1, x)
        x = self._conv_gelu_dropout(self.conv_block_1_layer_2, x)
        x = F.max_pool1d(x, 2)

        x = self._conv_gelu_dropout(self.conv_block_2_layer_1, x)
        x = self._conv_gelu_dropout(self.conv_block_2_layer_2, x)
        x = F.max_pool1d(x, 2)

        x = self._conv_gelu_dropout(self.conv_block_3_layer_1, x)
        x = self._conv_gelu_dropout(self.conv_block_3_layer_2, x)
        x = F.max_pool1d(x, 2)

        # Global max over the remaining sequence positions.
        x, _ = torch.max(x, -1)

        return x

# Embedding table size and width for the protein encoder.
# NOTE(review): fasta_stoi defines only 28 symbols, so most of the 1024 embedding
# rows are never indexed; shrinking it would change checkpoint shapes.
vocab_size, embedding_dim = 1024, 256
protein_cnn_encoder = ProteinCNNEncoder(vocab_size=vocab_size, embedding_dim=embedding_dim)

# seq_len = 100

# sample_input = torch.randint(vocab_size, (1, seq_len)).long()
# out = protein_cnn_encoder(sample_input)
# out.shape

In [7]:
class DTIPredictionHead(nn.Module):
    """Regression head on top of a molecule encoder and a protein encoder.

    The molecule encoder's ``pooler_output`` and the protein encoder's output are
    (optionally) projected to a common width, concatenated, and passed through a
    small MLP that emits one value per sample.

    Args:
        molecule_encoder: module exposing ``encoder.layer`` and returning an object
            with a ``pooler_output`` attribute when called with ``**molecule``.
        protein_encoder: module mapping the protein batch to (batch, protein_dim).
        molecule_dim: width of the molecule encoder's pooled output.
        protein_dim: width of the protein encoder's output.
        inner_dim: hidden width of the MLP (and of both projections).
        projection: whether to project both embeddings to ``inner_dim`` first.
    """

    def __init__(self, molecule_encoder, protein_encoder,
                 molecule_dim=128, protein_dim=256, inner_dim=256, projection=True):
        super().__init__()
        self.is_projection = projection
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze every layer in molecule_encoder.encoder.layer except the last one.
        # Nothing else in the molecule encoder is frozen here.
        for param in self.molecule_encoder.encoder.layer[0:-1].parameters():
            param.requires_grad = False

        if self.is_projection:
            self.mol_proj = nn.Linear(molecule_dim, inner_dim, bias=False)
            self.prot_proj = nn.Linear(protein_dim, inner_dim, bias=False)
            self.fc_1 = nn.Linear(inner_dim * 2, inner_dim)
        else:
            self.fc_1 = nn.Linear(molecule_dim + protein_dim, inner_dim)

        self.fc_2 = nn.Linear(inner_dim, inner_dim // 2)
        self.fc_out = nn.Linear(inner_dim // 2, 1)

    def forward(self, molecule, protein):
        """Predict one value per (molecule, protein) pair.

        Args:
            molecule: mapping of keyword arguments for the molecule encoder.
            protein: batch passed straight to the protein encoder.

        Returns:
            Tensor of shape (batch, 1).
        """
        molecule = self.molecule_encoder(**molecule).pooler_output
        protein = self.protein_encoder(protein)

        if self.is_projection:
            molecule = self.mol_proj(molecule)
            protein = self.prot_proj(protein)

        x = torch.cat((molecule, protein), -1)
        # F.dropout defaults to training=True; tie it to the module's mode so that
        # eval()/validation/prediction are deterministic.
        x = F.dropout(F.gelu(self.fc_1(x)), 0.1, training=self.training)
        x = F.dropout(F.gelu(self.fc_2(x)), 0.1, training=self.training)
        x = self.fc_out(x)

        return x


dti_prediction_head = DTIPredictionHead(molecule_bert, protein_cnn_encoder, projection=True)
# dti_prediction_head

In [8]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error


class DTI_prediction(pl.LightningModule):
    """Lightning wrapper that trains a DTI prediction head with an MSE objective.

    Every train/valid/test step logs the epoch-aggregated MSE loss and the mean
    absolute error under ``<stage>_loss`` / ``<stage>_mae``.
    """

    def __init__(self, dti_prediction_head):
        super().__init__()
        self.model = dti_prediction_head

    def forward(self, molecule_sequence, protein_sequence):
        """Delegate to the wrapped prediction head."""
        return self.model(molecule_sequence, protein_sequence)

    def _loss_and_log(self, batch, loss_key, mae_key):
        """Run one forward pass on ``batch``, log both metrics, return the MSE loss."""
        molecules, proteins, target = batch

        # The head emits (batch, 1); drop the trailing axis to line up with target.
        prediction = self(molecules, proteins).squeeze(-1)
        mse = F.mse_loss(prediction, target)

        self.log(loss_key, mse, on_step=False, on_epoch=True, prog_bar=True)
        self.log(mae_key, mean_absolute_error(prediction, target), on_step=False, on_epoch=True, prog_bar=True)

        return mse

    def training_step(self, batch, batch_idx):
        """Return the MSE loss to optimize; metrics are logged per epoch."""
        return self._loss_and_log(batch, 'train_loss', "train_mae")

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; nothing is returned."""
        self._loss_and_log(batch, 'valid_loss', "valid_mae")

    def test_step(self, batch, batch_idx):
        """Log test metrics; nothing is returned."""
        self._loss_and_log(batch, 'test_loss', "test_mae")

    def predict_step(self, batch, batch_idx):
        """Return the predictions for ``batch``; the target in the batch is ignored."""
        molecules, proteins, _ = batch
        return self(molecules, proteins).squeeze(-1)

    def configure_optimizers(self):
        """AdamW at lr=1e-4 with cosine annealing warm restarts (T_0=200)."""
        adamw = torch.optim.AdamW(self.parameters(), lr=1e-4)
        cosine_restarts = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(adamw, T_0=200)

        return {"optimizer": adamw, "lr_scheduler": cosine_restarts}
    
    
# Keep the 30 checkpoints with the lowest logged 'valid_loss'.
callbacks = [
    ModelCheckpoint(
        dirpath='weights/DTI_prediction_CLS_concatenate_with_projection',
        filename='attentional_dti-{epoch:03d}-{valid_loss:.4f}-{valid_mae:.4f}',
        monitor='valid_loss',
        save_top_k=30,
    ),
]

model = DTI_prediction(dti_prediction_head)

# Single-GPU training with 16-bit precision.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=1000,
    callbacks=callbacks,
    enable_progress_bar=True,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Fit the model, using train_dataloader for training and valid_dataloader for
# validation. The loaders are passed positionally on purpose: the keyword names
# of these parameters are not the same across pytorch-lightning versions.
trainer.fit(model, train_dataloader, valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | model | DTIPredictionHead | 11.7 M
--------------------------------------------
10.3 M    Trainable params
1.4 M     Non-trainable params
11.7 M    Total params
23.319    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]