In [1]:
device_no = 0
max_length = 512
fold_num = 1
PROJECT_NAME = f"prot-{max_length}_fold-{fold_num}"

import pickle
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name=f'{PROJECT_NAME}',
                           project='DLM_DTI_hint_based_learning_new')

import transformers
from transformers import AutoModel, BertTokenizer, RobertaTokenizer
from transformers import BertConfig, BertModel

train_df = pd.read_csv("data/train_dataset.csv")
valid_df = pd.read_csv("data/valid_dataset.csv")
test_df = pd.read_csv("data/test_dataset.csv")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjonghyunlee1993[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
with open(f"prot_feat/{max_length}_cls.pkl", "rb") as f:
    prot_feat_teacher = pickle.load(f)

class DTIDataset(Dataset):
    def __init__(self, data, prot_feat_teacher, 
                 mol_tokenizer, prot_tokenizer, max_length):
        self.data = data
        self.prot_feat_teacher = prot_feat_teacher
        self.max_length = max_length
        self.mol_tokenizer = mol_tokenizer
        self.prot_tokenizer = prot_tokenizer
        
    def get_mol_feat(self, smiles):
        return self.mol_tokenizer(smiles, max_length=512, truncation=True)
    
    def get_prot_feat_teacher(self, fasta):
        return self.prot_tokenizer(" ".join(fasta), max_length=self.max_length, truncation=True)
    
    def get_prot_feat_student(self, fasta):
        return self.prot_feat_teacher[fasta[:20]]
    
    def __len__(self):    
        return len(self.data)
    
    def __getitem__(self, index):
        smiles = self.data.loc[index, "SMILES"]
        mol_feat = self.get_mol_feat(smiles)
        
        fasta = self.data.loc[index, "Target Sequence"]
        prot_feat_student = self.get_prot_feat_teacher(fasta)
        prot_feat_teacher = self.get_prot_feat_student(fasta)
        
        y = self.data.loc[index, "Label"]
        source = self.data.loc[index, "Source"]
        if source == "DAVIS":
            source = 1
        elif source == "BindingDB":
            source = 2
        elif source == "BIOSNAP":
            source = 3
                
        return mol_feat, prot_feat_student, prot_feat_teacher, y, source

    
def collate_batch(batch):
    mol_features, prot_feat_student, prot_feat_teacher, y, source = [], [], [], [], []
    
    for (mol_seq, prot_seq_student, prot_seq_teacher, y_, source_) in batch:
        mol_features.append(mol_seq)
        prot_feat_student.append(prot_seq_student)
        prot_feat_teacher.append(prot_seq_teacher.detach().cpu().numpy().tolist())
        y.append(y_)
        source.append(source_)
        
    mol_features = mol_tokenizer.pad(mol_features, return_tensors="pt")
    prot_feat_student = prot_tokenizer.pad(prot_feat_student, return_tensors="pt")
    prot_feat_teacher = torch.tensor(prot_feat_teacher).float()
    y = torch.tensor(y).float()
    source = torch.tensor(source)
    
    return mol_features, prot_feat_student, prot_feat_teacher, y, source

In [3]:
mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mol_encoder = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# prot_encoder = AutoModel.from_pretrained("Rostlab/prot_bert")

for param in mol_encoder.embeddings.parameters():
    param.requires_grad = False

for layer in mol_encoder.encoder.layer[:6]:
    for param in layer.parameters():
        param.requires_grad = False

config = BertConfig(
    vocab_size=prot_tokenizer.vocab_size,
    hidden_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=2048,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=max_length + 2,
    type_vocab_size=1,
    pad_token_id=0,
    position_embedding_type="absolute"
)

prot_encoder = BertModel(config)

Some weights of the model checkpoint at seyonec/ChemBERTa-zinc-base-v1 were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
train_dataset = DTIDataset(train_df, prot_feat_teacher, 
                           mol_tokenizer, prot_tokenizer, max_length)
valid_dataset = DTIDataset(valid_df, prot_feat_teacher, 
                           mol_tokenizer, prot_tokenizer, max_length)
test_dataset = DTIDataset(test_df, prot_feat_teacher, 
                          mol_tokenizer, prot_tokenizer, max_length)

counts = np.bincount(train_df["Label"])
labels_weights = 1. / counts
weights = labels_weights[train_df["Label"]]
sampler = WeightedRandomSampler(weights, len(weights))

train_dataloader = DataLoader(train_dataset, batch_size=128, num_workers=32, 
                              pin_memory=True, prefetch_factor=10, drop_last=True, 
                              sampler=sampler, collate_fn=collate_batch)

valid_dataloader = DataLoader(train_dataset, batch_size=128, num_workers=32, 
                              pin_memory=True, prefetch_factor=10, 
                              drop_last=False, collate_fn=collate_batch)

test_dataloader = DataLoader(train_dataset, batch_size=128, num_workers=32, 
                             pin_memory=True, prefetch_factor=10, 
                             drop_last=False, collate_fn=collate_batch)

In [5]:
class DTI(nn.Module):
    def __init__(self, mol_encoder, prot_encoder, 
                 hidden_dim=512, mol_dim=128, prot_dim=1024):
        super().__init__()
        self.mol_encoder = mol_encoder
        self.prot_encoder = prot_encoder
        
        self.lambda_ = torch.nn.Parameter(torch.rand(1).to(f"cuda:{device_no}"), requires_grad=True)
                    
        self.molecule_align = nn.Sequential(
            nn.LayerNorm(mol_dim),
            nn.Linear(mol_dim, hidden_dim, bias=False)
        )
        
        self.protein_align_teacher = nn.Sequential(
            nn.LayerNorm(1024),
            nn.Linear(1024, hidden_dim, bias=False)
        )
        
        self.protein_align_student = nn.Sequential(
            nn.LayerNorm(prot_dim),
            nn.Linear(prot_dim, hidden_dim, bias=False)
        )
        
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        
        self.cls_out = nn.Linear(hidden_dim, 1)
        
    def forward(self, SMILES, FASTA, prot_feat_teacher):
        mol_feat = self.mol_encoder(**SMILES).last_hidden_state[:, 0]
        prot_feat = self.prot_encoder(**FASTA).last_hidden_state[:, 0]
        
        mol_feat = self.molecule_align(mol_feat)
        prot_feat = self.protein_align_student(prot_feat)
        prot_feat_teacher = self.protein_align_teacher(prot_feat_teacher).squeeze(1)
        
        lambda_ = torch.sigmoid(self.lambda_)
        merged_prot_feat = lambda_ * prot_feat + (1 - lambda_) * prot_feat_teacher
    
        x = torch.cat([mol_feat, merged_prot_feat], dim=1)

        x = F.dropout(F.gelu(self.fc1(x)), 0.1)
        x = F.dropout(F.gelu(self.fc2(x)), 0.1)
        x = F.dropout(F.gelu(self.fc3(x)), 0.1)
        
        cls_out = self.cls_out(x).squeeze(-1)
        
        return cls_out, lambda_
        
model = DTI(mol_encoder, prot_encoder,
            hidden_dim=512, mol_dim=768, prot_dim=512)

In [6]:
class DTI_prediction(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    
    def step(self, batch):
        mol_feature, prot_feat_student, prot_feat_teacher, y, source = batch
        prot_feat_teacher = prot_feat_teacher.detach()
        pred, lambda_ = self.model(mol_feature, prot_feat_student, prot_feat_teacher)
        
        loss = F.binary_cross_entropy_with_logits(pred, y)
        
        pred = F.sigmoid(pred)
        auroc = torchmetrics.functional.auroc(pred, y.long())
        auprc = torchmetrics.functional.average_precision(pred, y.long())
        
        return pred, y, source, loss, lambda_
        
    
    def training_step(self, batch, batch_idx):
        _, _, _, loss, lambda_ = self.step(batch)
        
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_lambda_', lambda_, on_step=True, on_epoch=True, prog_bar=False)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        preds, y, _, loss, lambda_ = self.step(batch)
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_lambda_', lambda_, on_step=False, on_epoch=True, prog_bar=True)
        
        return {'preds': preds, 'target': y}


    def validation_epoch_end(self, outputs):
        preds = torch.cat([tmp['preds'] for tmp in outputs], 0).detach().cpu()
        targets = torch.cat([tmp['target'] for tmp in outputs], 0).detach().cpu().long()

        auroc = torchmetrics.functional.auroc(preds, targets.long())
        auprc = torchmetrics.functional.average_precision(preds, targets.long())
        self.log('valid_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        
#         conf_mat = torchmetrics.functional.confusion_matrix(preds, targets, num_classes=2)

#         print(f'Epoch : {self.trainer.current_epoch}')
#         print(conf_mat)
    
    
    def test_step(self, batch, batch_idx):
        preds, y, _, loss, _ = self.step(batch)
        self.logging([loss, auroc, auprc], mode='test')
        
        return {'preds': preds, 'target': y}
    
    
    def test_epoch_end(self, outputs):
        preds = torch.cat([tmp['preds'] for tmp in outputs], 0).detach().cpu()
        targets = torch.cat([tmp['target'] for tmp in outputs], 0).detach().cpu().long()
        
        auroc = torchmetrics.functional.auroc(preds, targets.long())
        auprc = torchmetrics.functional.average_precision(preds, targets.long())
        self.log('test_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        
        conf_mat = torchmetrics.functional.confusion_matrix(preds, targets, num_classes=2)

        print(conf_mat)
    
    
    def predict_step(self, batch, batch_idx):
        pred, y, source, _, _ = self.step(batch)
        
        return pred, y, source

    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10*len(train_dataloader))
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
          
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss',
                    save_top_k=1, dirpath=f'weights/{PROJECT_NAME}', 
                    filename='DTI-{epoch:03d}-{valid_loss:.4f}-{valid_auroc:.4f}-{valid_auprc:.4f}'),
]

predictor = DTI_prediction(model)
trainer = pl.Trainer(max_epochs=30, gpus=[device_no], enable_progress_bar=True, 
                     callbacks=callbacks, precision=16, logger=wandb_logger)


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
trainer.fit(predictor, train_dataloader, valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | DTI  | 63.2 M
-------------------------------
19.6 M    Trainable params
43.5 M    Non-trainable params
63.2 M    Total params
126.326   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [8]:
predictor = predictor.load_from_checkpoint(f"weights/{PROJECT_NAME}/DTI-epoch=029-valid_loss=0.2288-valid_auroc=0.9699-valid_auprc=0.9696.ckpt",
                                          model=model)
out = trainer.predict(predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 265it [00:00, ?it/s]

In [9]:
davis_pred, davis_target = [], []
binding_pred, binding_target = [], []
biosnap_pred, biosnap_target = [], []

for batch in out:
    for i in range(batch[0].shape[0]):
        pred = batch[0][i].detach().numpy().tolist()
        target = batch[1][i].detach().numpy().tolist()
        source = batch[2][i].detach().numpy().tolist()

        if source == 1:
            davis_pred.append(pred)
            davis_target.append(target)
        elif source == 2:
            binding_pred.append(pred)
            binding_target.append(target)
        elif source == 3:
            biosnap_pred.append(pred)
            biosnap_target.append(target)        

In [10]:
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix

davis_pred_label = np.where(np.array(davis_pred) >= 0.5, 1, 0)
binding_pred_label = np.where(np.array(binding_pred) >= 0.5, 1, 0)
biosnap_pred_label = np.where(np.array(biosnap_pred) >= 0.5, 1, 0)

def compute_sen_spec(y_test, y_pred):
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
    
    sensitivity = (tp / (tp + fn)).round(4)
    specificity = (tn / (tn + fp)).round(4)

    return sensitivity, specificity

davis_auroc = roc_auc_score(davis_target, davis_pred).round(4)
davis_auprc = average_precision_score(davis_target, davis_pred).round(4)
davis_sen, davis_spec = compute_sen_spec(davis_target, davis_pred_label)

binding_auroc = roc_auc_score(binding_target, binding_pred).round(4)
binding_auprc = average_precision_score(binding_target, binding_pred).round(4)
binding_sen, binding_spec = compute_sen_spec(binding_target, binding_pred_label)

biosnap_auroc = roc_auc_score(biosnap_target, biosnap_pred).round(4)
biosnap_auprc = average_precision_score(biosnap_target, biosnap_pred).round(4)
biosnap_sen, biosnap_spec = compute_sen_spec(biosnap_target, biosnap_pred_label)

print(f"DAVIS\tAUROC:{davis_auroc}\tAUPRC:{davis_auprc}\tSen:{davis_sen}\tSpec:{davis_spec}")
print(f"Binding\tAUROC:{binding_auroc}\tAUPRC:{binding_auprc}\tSen:{binding_sen}\tSpec:{binding_spec}")
print(f"BIOSNAP\tAUROC:{biosnap_auroc}\tAUPRC:{biosnap_auprc}\tSen:{biosnap_sen}\tSpec:{biosnap_spec}")

DAVIS	AUROC:0.9582	AUPRC:0.9578	Sen:0.9492	Spec:0.7814
Binding	AUROC:0.9694	AUPRC:0.9688	Sen:0.9318	Spec:0.8688
BIOSNAP	AUROC:0.9714	AUPRC:0.9714	Sen:0.9294	Spec:0.8868
