In [1]:
device_no = 0
max_length = 256

import pickle
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import AutoModel, BertTokenizer, RobertaTokenizer
from transformers import BertConfig, BertModel

df = pd.read_excel("data/validation.xlsx", engine="openpyxl")
df

Unnamed: 0,Drug_name,SMILES,Target_name,FASTA
0,Phenacetin,CCOC1=CC=C(C=C1)NC(=O)C,1A2,MAVLKGLRPRVPKGLKSPPEPWGWPLLGHVLTLGKNPHLALSRMSQ...
1,7-Ethoxyresorufin,CCOC1=CC2=C(C=C1)N=C3C=CC(=O)C=C3O2,1A2,MAVLKGLRPRVPKGLKSPPEPWGWPLLGHVLTLGKNPHLALSRMSQ...
2,Efavirenz,C1CC1C#CC2(C3=C(C=CC(=C3)Cl)NC(=O)O2)C(F)(F)F,2B6,MAKKTSSKGKLPPGPRPLPLLGNLLQMDRRGLLKSFLRFREKYGDV...
3,Bupropion,CC(C(=O)C1=CC(=CC=C1)Cl)NC(C)(C)C\n,2B6,MAKKTSSKGKLPPGPRPLPLLGNLLQMDRRGLLKSFLRFREKYGDV...
4,Midazolam,CC1=NC=C2N1C3=C(C=C(C=C3)Cl)C(=NC2)C4=CC=CC=C4F,3A4,MAYLYGTHSHGLFKKLGIPGPTPLPFLGNILSYHKGFCMFDMECHK...
5,Testosterone,CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C,3A4,MAYLYGTHSHGLFKKLGIPGPTPLPFLGNILSYHKGFCMFDMECHK...
6,Paclitaxel,CC1=C2C(C(=O)C3(C(CC4C(C3C(C(C2(C)C)(CC1OC(=O)...,2C8,MAKKTSSKGKLPPGPTPLPIIGNMLQIDVKDICKSFTNFSKVYGPV...
7,Amodiaquine,CCN(CC)CC1=C(C=CC(=C1)NC2=C3C=CC(=CC3=NC=C2)Cl)O,2C8,MAKKTSSKGKLPPGPTPLPIIGNMLQIDVKDICKSFTNFSKVYGPV...
8,S-Warfarin,CC(=O)CC(C1=CC=CC=C1)C2=C(C3=CC=CC=C3OC2=O)O,2C9,MAKKTSGRGKLPPGPTPLPVIGNILQIGIKDISKSLTNLSKVYGPV...
9,Diclofenac,C1=CC=C(C(=C1)CC(=O)O)NC2=C(C=CC=C2Cl)Cl,2C9,MAKKTSGRGKLPPGPTPLPVIGNILQIGIKDISKSLTNLSKVYGPV...


In [2]:
class Preprocessor:
    def __init__(self, mol_tokenizer, prot_tokenizer, device_no, prot_max_length=256):
        self.prot_tokenizer = prot_tokenizer
        self.mol_tokenizer = mol_tokenizer
        self.prot_encoder = AutoModel.from_pretrained("Rostlab/prot_bert")
        self.prot_encoder.eval()
        
        self.device_no = device_no
        self.prot_max_length = prot_max_length
        
    def encode_fasta(self, fasta):
        target_seq = self.prot_tokenizer(" ".join(fasta), max_length=self.prot_max_length, return_tensors="pt")
        
        return target_seq
    
    def encode_smiles(self, smiles):
        drug_seq = self.mol_tokenizer(smiles, max_length=512, return_tensors="pt")
        
        return drug_seq
    
    def get_hint(self, fasta):
        target_seq = self.encode_fasta(fasta)
        hint = self.prot_encoder(**target_seq)
        hint = hint.last_hidden_state.detach().to("cpu")
        
        return hint[:, 0]
    
    def get_feat(self, fasta, smiles):
        hint = self.get_hint(fasta)
        target_seq = self.encode_fasta(fasta)
        drug_seq = self.encode_smiles(smiles)
        
        return hint, target_seq, drug_seq

In [3]:
def load_tokenizer():
    prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

    return prot_tokenizer, mol_tokenizer

prot_tokenizer, mol_tokenizer = load_tokenizer()
mol_encoder = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

config = BertConfig(
    vocab_size=prot_tokenizer.vocab_size,
    hidden_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=2048,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=max_length + 2,
    type_vocab_size=1,
    pad_token_id=0,
    position_embedding_type="absolute"
)

prot_encoder = BertModel(config)

Some weights of the model checkpoint at seyonec/ChemBERTa-zinc-base-v1 were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
class DTI(nn.Module):
    def __init__(self, mol_encoder, prot_encoder, 
                 hidden_dim=512, mol_dim=128, prot_dim=1024):
        super().__init__()
        self.mol_encoder = mol_encoder
        self.prot_encoder = prot_encoder
        
        self.lambda_ = torch.nn.Parameter(torch.rand(1).to(f"cuda:{device_no}"), requires_grad=True)
                    
        self.molecule_align = nn.Sequential(
            nn.LayerNorm(mol_dim),
            nn.Linear(mol_dim, hidden_dim, bias=False)
        )
        
        self.protein_align_teacher = nn.Sequential(
            nn.LayerNorm(1024),
            nn.Linear(1024, hidden_dim, bias=False)
        )
        
        self.protein_align_student = nn.Sequential(
            nn.LayerNorm(prot_dim),
            nn.Linear(prot_dim, hidden_dim, bias=False)
        )
        
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        
        self.cls_out = nn.Linear(hidden_dim, 1)
        
    def forward(self, SMILES, FASTA, prot_feat_teacher):
        mol_feat = self.mol_encoder(**SMILES).last_hidden_state[:, 0]
        prot_feat = self.prot_encoder(**FASTA).last_hidden_state[:, 0]
        
        mol_feat = self.molecule_align(mol_feat)
        prot_feat = self.protein_align_student(prot_feat)
        prot_feat_teacher = self.protein_align_teacher(prot_feat_teacher).squeeze(1)
        
        lambda_ = torch.sigmoid(self.lambda_)
        merged_prot_feat = lambda_ * prot_feat + (1 - lambda_) * prot_feat_teacher
    
        x = torch.cat([mol_feat, merged_prot_feat], dim=1)

        x = F.dropout(F.gelu(self.fc1(x)), 0.1)
        x = F.dropout(F.gelu(self.fc2(x)), 0.1)
        x = F.dropout(F.gelu(self.fc3(x)), 0.1)
        
        cls_out = self.cls_out(x).squeeze(-1)
        
        return cls_out, lambda_
        
model = DTI(mol_encoder, prot_encoder,
            hidden_dim=512, mol_dim=768, prot_dim=512)

model.load_state_dict(torch.load("weights/DLM_DTI_prot-256_.pt"))
model = model.to(f"cuda:{device_no}")
model.eval()

DTI(
  (mol_encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(767, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps

In [7]:
import time

preprocessor = Preprocessor(mol_tokenizer, prot_tokenizer, device_no, prot_max_length=256)

time_list = []
for i in range(len(df)):
    
    start = time.time()
    hint, target_seq, drug_seq = preprocessor.get_feat(df.loc[i, "FASTA"], df.loc[i, "SMILES"])
    prob, _ = model(drug_seq.to(f"cuda:{device_no}"), target_seq.to(f"cuda:{device_no}"), hint.to(f"cuda:{device_no}"))
    end = time.time()
    
    time_list.append(end - start)
    print(F.sigmoid(prob))
    
time_list = np.array(time_list)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([0.9986], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([1.0000], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9997], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9987], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9930], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9996], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9972], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9993], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9999], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.9994], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0.8251], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([1.0000], device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([1.0000], device='cuda:0', grad_fn=<SigmoidBackward0>)


In [9]:
np.mean(time_list), np.std(time_list)

(0.35838868067814755, 0.004643884170554766)