In [1]:
PROJECT_NAME = "CrossViT_MSE_half_freeze"
LEARNING_RATE = 1e-5
PROT_MAX_LEN = 545

import pickle
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset
from torchmetrics.functional import average_precision
from torchmetrics.functional.classification import binary_auroc

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name=f'{PROJECT_NAME}_lr-{LEARNING_RATE}_prot_{PROT_MAX_LEN}',
                           project='DLM_DTI')

import transformers
from transformers import AutoModel, BertTokenizer, RobertaTokenizer

train_data = pd.read_csv("data/mol_trans/train_dataset.csv")
valid_data = pd.read_csv("data/mol_trans/valid_dataset.csv")
test_data = pd.read_csv("data/mol_trans/test_dataset.csv")
    
mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mol_encoder = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
prot_encoder = AutoModel.from_pretrained("Rostlab/prot_bert")

for param in prot_encoder.embeddings.parameters():
    param.requires_grad = False

for layer in prot_encoder.encoder.layer[:16]:
    for param in layer.parameters():
        param.requires_grad = False

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjonghyunlee1993[0m. Use [1m`wandb login --relogin`[0m to force relogin


Some weights of the model checkpoint at seyonec/ChemBERTa-zinc-base-v1 were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias

In [2]:
from typing import Callable

class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices: a list of indices
        num_samples: number of samples to draw
        callback_get_label: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(
        self,
        dataset,
        labels=None,
        indices=None,
        num_samples=None,
        callback_get_label=None,
    ):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # define custom callback
        self.callback_get_label = dataset.data.Label

        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # distribution of classes in the dataset
        df = pd.DataFrame()
        df["Label"] = self._get_labels(dataset) if labels is None else labels
        df.index = self.indices
        df = df.sort_index()

        label_to_count = df["Label"].value_counts()

        weights = 1.0 / label_to_count[df["Label"]]

        self.weights = torch.DoubleTensor(weights.to_list())

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
    

class DTIDataset(Dataset):
    def __init__(self, data, mol_tokenizer, prot_tokenizer):
        self.data = data
        self.mol_tokenizer = mol_tokenizer
        self.prot_tokenizer = prot_tokenizer
        
    def get_mol_feature(self, smiles):
        return self.mol_tokenizer(smiles, max_length=512, truncation=True)
    
    def get_prot_feature(self, fasta):
        return self.prot_tokenizer(" ".join(fasta), max_length=PROT_MAX_LEN, truncation=True)
    
    def __len__(self):    
        return len(self.data)
    
    def __getitem__(self, index):
        smiles = self.data.loc[index, "SMILES"]
        mol_feature = self.get_mol_feature(smiles)
        
        fasta = self.data.loc[index, "Target Sequence"]
        prot_feature = self.get_prot_feature(fasta)
        
        y = self.data.loc[index, "Label"]
        source = self.data.loc[index, "Source"]
                
        return mol_feature, prot_feature, y, source
    
def collate_batch(batch):
    mol_features, prot_features, y, source = [], [], [], []
    
    for (mol_seq, prot_seq, y_, source_) in batch:
        mol_features.append(mol_seq)
        prot_features.append(prot_seq)
        y.append(y_)
        source.append(source_)
        
    mol_features = mol_tokenizer.pad(mol_features, return_tensors="pt")
    prot_features = prot_tokenizer.pad(prot_features, return_tensors="pt")
    y = torch.tensor(y).float()
    source = torch.tensor(source)
    
    return mol_features, prot_features, y, source


train_dataset = DTIDataset(train_data, mol_tokenizer, prot_tokenizer)
valid_dataset = DTIDataset(valid_data, mol_tokenizer, prot_tokenizer)
test_dataset = DTIDataset(test_data, mol_tokenizer, prot_tokenizer)


train_dataset = DTIDataset(train_data, mol_tokenizer, prot_tokenizer)
valid_dataset = DTIDataset(valid_data, mol_tokenizer, prot_tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=32, num_workers=16, 
                              pin_memory=True, prefetch_factor=10, drop_last=True, 
                              sampler=ImbalancedDatasetSampler(train_dataset, labels=train_dataset.data.Label),
                              collate_fn=collate_batch)

valid_dataloader = DataLoader(valid_dataset, batch_size=32, num_workers=16, 
                              pin_memory=True, prefetch_factor=10,
                              collate_fn=collate_batch)

test_dataloader = DataLoader(test_dataset, batch_size=32, num_workers=16, 
                             pin_memory=True, prefetch_factor=10,
                             collate_fn=collate_batch)

In [3]:
class CrossAttention(nn.Module):
    def __init__(self, input_dim=128, intermediate_dim=512, heads=8, dropout=0.1):
        super().__init__()
        project_out = input_dim

        self.heads = heads
        self.scale = (input_dim / heads) ** -0.5

        self.key = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.value = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.query = nn.Linear(input_dim, intermediate_dim, bias=False)

        self.out = nn.Sequential(
            nn.Linear(intermediate_dim, project_out),
            nn.Dropout(dropout)
        )

        
    def forward(self, data):
        b, n, d, h = *data.shape, self.heads

        k = self.key(data)
        k = rearrange(k, 'b n (h d) -> b h n d', h=h)

        v = self.value(data)
        v = rearrange(v, 'b n (h d) -> b h n d', h=h)
        
        # get only cls token
        q = self.query(data[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h=h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attention = dots.softmax(dim=-1)

        output = einsum('b h i j, b h j d -> b h i d', attention, v)
        output = rearrange(output, 'b h n d -> b n (h d)')
        output = self.out(output)
        
        return output


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

    
class CrossAttentionLayer(nn.Module):
    def __init__(self, 
                 molecule_dim=768, molecule_intermediate_dim=1024,
                 protein_dim=1024, protein_intermediate_dim=2048,
                 cross_attn_depth=1, cross_attn_heads=4, dropout=0.1):
        super().__init__()

        self.cross_attn_layers = nn.ModuleList([])
        
        for _ in range(cross_attn_depth):
            self.cross_attn_layers.append(nn.ModuleList([
                nn.Linear(molecule_dim, protein_dim),
                nn.Linear(protein_dim, molecule_dim),
                PreNorm(protein_dim, CrossAttention(
                    protein_dim, protein_intermediate_dim, cross_attn_heads, dropout
                )),
                nn.Linear(protein_dim, molecule_dim),
                nn.Linear(molecule_dim, protein_dim),
                PreNorm(molecule_dim, CrossAttention(
                    molecule_dim, molecule_intermediate_dim, cross_attn_heads, dropout
                ))
            ]))

            
    def forward(self, molecule, protein):
        for i, (f_sl, g_ls, cross_attn_s, f_ls, g_sl, cross_attn_l) in enumerate(self.cross_attn_layers):
            
            cls_molecule = molecule[:, 0]
            x_molecule = molecule[:, 1:]
            
            cls_protein = protein[:, 0]
            x_protein = protein[:, 1:]

            # Cross attention for protein sequence
            cal_q = f_ls(cls_protein.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_molecule), dim=1)
            # add activation function
            cal_out = cal_q + cross_attn_l(cal_qkv)
            cal_out = F.gelu(g_sl(cal_out))
            protein = torch.cat((cal_out, x_protein), dim=1)

            # Cross attention for molecule sequence
            cal_q = f_sl(cls_molecule.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_protein), dim=1)
            # add activation function
            cal_out = cal_q + cross_attn_s(cal_qkv)
            cal_out = F.gelu(g_ls(cal_out))
            molecule = torch.cat((cal_out, x_molecule), dim=1)
            
        return molecule, protein
    

class DTI(nn.Module):
    def __init__(self, mol_encoder, prot_encoder, 
                 hidden_dim=512, mol_dim=128, prot_dim=1024):
        super().__init__()
        self.mol_encoder = mol_encoder
        self.prot_encoder = prot_encoder
        
        self.cross_attention = CrossAttentionLayer(cross_attn_depth=2, cross_attn_heads=4)
        
        self.molecule_align = nn.Sequential(
            nn.LayerNorm(mol_dim),
            nn.Linear(mol_dim, hidden_dim, bias=False)
        )
        
        self.protein_align = nn.Sequential(
            nn.LayerNorm(prot_dim),
            nn.Linear(prot_dim, hidden_dim, bias=False)
        )       
        
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        
        self.cls_out = nn.Linear(hidden_dim, 1)

    
    def forward(self, SMILES, FASTA):
        mol_feature = self.mol_encoder(**SMILES).last_hidden_state
        prot_feature = self.prot_encoder(**FASTA).last_hidden_state
        
        mol_feature, prot_feature = self.cross_attention(mol_feature, prot_feature)
        mol_feature = mol_feature[:, 0]
        prot_feature = prot_feature[:, 0]
        
        mol_feature = self.molecule_align(mol_feature)
        prot_feature = self.protein_align(prot_feature)
        
        x = torch.cat([mol_feature, prot_feature], dim=1)

        x = F.dropout(F.gelu(self.fc1(x)), 0.1)
        x = F.dropout(F.gelu(self.fc2(x)), 0.1)
        x = F.dropout(F.gelu(self.fc3(x)), 0.1)
        
        cls_out = self.cls_out(x).squeeze(-1)
        
        return F.tanh(cls_out)
    
model = DTI(mol_encoder, prot_encoder,
            hidden_dim=512, mol_dim=768, prot_dim=1024)


In [4]:
class DTI_prediction(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    
    def step(self, batch):
        mol_feature, prot_feature, y, source = batch
        pred = self.model(mol_feature, prot_feature).squeeze(-1)
        
#         loss = F.binary_cross_entropy_with_logits(pred, y)
        loss = F.smooth_l1_loss(pred, y)
    
        auroc = binary_auroc(pred, y)
        auprc = average_precision(pred, y)
        
        return pred, source, loss, auroc, auprc, 
    
    
    def training_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc = self.step(batch)
        
        self.log('train_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc = self.step(batch)
        
        self.log('valid_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def test_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc = self.step(batch)
        
        self.log('test_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def predict_step(self, batch, batch_idx):
        pred, source, _, _, _ = self.step(batch)
        
        return pred, batch[2], source

    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
          
    
callbacks = [
    ModelCheckpoint(monitor='valid_auroc', mode="max",
                    save_top_k=5, dirpath=f'weights/{PROJECT_NAME}', filename='DTI-{epoch:03d}-{valid_loss:.4f}-{valid_auroc:.4f}-{valid_auprc:.4f}'),
]

predictor = DTI_prediction(model)
trainer = pl.Trainer(max_epochs=200, gpus=[1], enable_progress_bar=True, 
                     callbacks=callbacks, logger=wandb_logger, precision=16)

  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(predictor, train_dataloader, valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | DTI  | 499 M 
-------------------------------
256 M     Trainable params
242 M     Non-trainable params
499 M     Total params
998.114   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [6]:
predictor = predictor.load_from_checkpoint(
    "weights/CrossViT_MSE//DTI-epoch=036-valid_loss=0.0703-valid_auroc=0.9005-valid_auprc=nan.ckpt",
    model=model
)

pred_out = trainer.predict(predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

In [10]:
results = np.array([]).reshape(0, 3)

for pred, label, source in pred_out:
    line_ = np.array([pred.detach().numpy(), 
         label.detach().numpy(), 
         source.detach().numpy()]).T
    results = np.vstack([results, line_])

results = pd.DataFrame(results, columns=["pred", "y", "source"])
results.y = results.y.astype(int)
results.source = results.source.astype(int)
results.loc[results.pred < 0, "pred"] = 0
results.head()

Unnamed: 0,pred,y,source
0,0.0,0,0
1,0.0,0,0
2,0.977051,0,0
3,0.346191,0,0
4,0.996094,1,0


In [11]:
def get_evaluation_metrics(df, source=0):
    from sklearn.metrics import average_precision_score, roc_auc_score
    
    def get_cindex(Y, P):
        summ = 0
        pair = 0

        for i in range(1, len(Y)):
            for j in range(0, i):
                if i is not j:
                    if(Y[i] > Y[j]):
                        pair +=1
                        summ +=  1* (P[i] > P[j]) + 0.5 * (P[i] == P[j])

        if pair is not 0:
            return summ/pair
        else:
            return 0

    def r_squared_error(y_obs,y_pred):
        y_obs = np.array(y_obs)
        y_pred = np.array(y_pred)
        y_obs_mean = [np.mean(y_obs) for y in y_obs]
        y_pred_mean = [np.mean(y_pred) for y in y_pred]

        mult = sum((y_pred - y_pred_mean) * (y_obs - y_obs_mean))
        mult = mult * mult

        y_obs_sq = sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean))
        y_pred_sq = sum((y_pred - y_pred_mean) * (y_pred - y_pred_mean) )

        return mult / float(y_obs_sq * y_pred_sq)

    def get_k(y_obs,y_pred):
        y_obs = np.array(y_obs)
        y_pred = np.array(y_pred)

        return sum(y_obs*y_pred) / float(sum(y_pred*y_pred))

    def squared_error_zero(y_obs,y_pred):
        k = get_k(y_obs,y_pred)

        y_obs = np.array(y_obs)
        y_pred = np.array(y_pred)
        y_obs_mean = [np.mean(y_obs) for y in y_obs]
        upp = sum((y_obs - (k*y_pred)) * (y_obs - (k* y_pred)))
        down= sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean))

        return 1 - (upp / float(down))


    def get_rm2(ys_orig, ys_line):
        r2 = r_squared_error(ys_orig, ys_line)
        r02 = squared_error_zero(ys_orig, ys_line)

        return r2 * (1 - np.sqrt(np.absolute((r2*r2)-(r02*r02))))
    
    
    source_df = df[df.source == source].reset_index(drop=True)
    auroc = roc_auc_score(source_df.y, source_df.pred)
    auprc = average_precision_score(source_df.y, source_df.pred)

    
    if source == 0:
        dataset = "Davis"
    elif source == 1:
        dataset = "BindingDB"
    elif source == 2:
        dataset = "BIOSNAP"
        
    print(f"Dataset: {dataset}")
    print("AUROC\tAUPRC")
    print(f"{auroc.round(4)}\t{auprc.round(4)}")
    print()
    
get_evaluation_metrics(results, source=0)
get_evaluation_metrics(results, source=1)
get_evaluation_metrics(results, source=2)

Dataset: Davis
AUROC	AUPRC
0.8962	0.3446

Dataset: BindingDB
AUROC	AUPRC
0.884	0.8712

Dataset: BIOSNAP
AUROC	AUPRC
0.9077	0.5844

