In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from transformers import BertModel, BertTokenizer

import pickle
import pandas as pd
from tqdm import tqdm
from tdc.multi_pred import DTI

kiba = DTI(name="kiba")
kiba_split = kiba.get_split()

train_df = kiba_split['train']
valid_df = kiba_split['valid']
test_df = kiba_split['test']

with open("data/drug/kiba_drug_embeddings.pkl", "rb") as f:
    molecule_embeddings = pickle.load(f)
    
with open("data/target/kiba_target_embeddings.pkl", "rb") as f:
    protein_embeddings = pickle.load(f) 

Found local copy...
Loading...
Done!


In [2]:
class DTIDataset(Dataset):
    def __init__(self, data, molecule_embedding, protein_embedding):
        self.data = data
        
        self.molecule_embedding = molecule_embedding
        self.protein_embedding = protein_embedding
    
        
    def __len__(self):
        return len(self.data)

    
    def __getitem__(self, idx):
        # original embedding
        molecule_sequence = self.molecule_embedding[self.data.loc[idx, "Drug_ID"]][1]
        # reduced embedding
        protein_sequence = self.protein_embedding[self.data.loc[idx, "Target_ID"]][2]
        y = self.data.loc[idx, "Y"]
                
        return molecule_sequence, protein_sequence, y

    
def collate_batch(batch):
    molecule_seq, protein_seq, y = [], [], []
    
    for (molecule_seq_, protein_seq_, y_) in batch:
        molecule_seq.append(molecule_seq_)
        protein_seq.append(protein_seq_)
        y.append(y_)
    
    molecule_seq = torch.nn.utils.rnn.pad_sequence([torch.Tensor(t) for t in molecule_seq], batch_first=True)
    protein_seq = torch.nn.utils.rnn.pad_sequence([torch.Tensor(t) for t in protein_seq], batch_first=True)
    y = torch.tensor(y).float()
    
    return molecule_seq, protein_seq, y


train_dataset = DTIDataset(train_df, molecule_embeddings['train'], protein_embeddings['train'])
valid_dataset = DTIDataset(valid_df, molecule_embeddings['valid'], protein_embeddings['valid'])
test_dataset = DTIDataset(test_df, molecule_embeddings['test'], protein_embeddings['test'])

batch_size = 512

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=12, 
                              pin_memory=True, prefetch_factor=10, 
                              drop_last=True, collate_fn=collate_batch, shuffle=True)

valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=12, 
                              pin_memory=True, prefetch_factor=10, 
                              collate_fn=collate_batch)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=12, 
                             pin_memory=True, prefetch_factor=10, 
                             collate_fn=collate_batch)

In [3]:
class CrossAttention(nn.Module):
    def __init__(self, input_dim=128, intermediate_dim=512, heads=8, dropout=0.1):
        super().__init__()
        project_out = input_dim

        self.heads = heads
        self.scale = (input_dim / heads) ** -0.5

        self.key = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.value = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.query = nn.Linear(input_dim, intermediate_dim, bias=False)

        self.out = nn.Sequential(
            nn.Linear(intermediate_dim, project_out),
            nn.Dropout(dropout)
        )

        
    def forward(self, data):
        b, n, d, h = *data.shape, self.heads

        k = self.key(data)
        k = rearrange(k, 'b n (h d) -> b h n d', h=h)

        v = self.value(data)
        v = rearrange(v, 'b n (h d) -> b h n d', h=h)
        
        # get only cls token
        q = self.query(data[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h=h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attention = dots.softmax(dim=-1)

        output = einsum('b h i j, b h j d -> b h i d', attention, v)
        output = rearrange(output, 'b h n d -> b n (h d)')
        output = self.out(output)
        
        return output


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

    
class CrossAttentionLayer(nn.Module):
    def __init__(self, 
                 molecule_dim=128, molecule_intermediate_dim=256,
                 protein_dim=256, protein_intermediate_dim=512,
                 cross_attn_depth=1, cross_attn_heads=4, dropout=0.1):
        super().__init__()

        self.cross_attn_layers = nn.ModuleList([])
        
        for _ in range(cross_attn_depth):
            self.cross_attn_layers.append(nn.ModuleList([
                nn.Linear(molecule_dim, protein_dim),
                nn.Linear(protein_dim, molecule_dim),
                PreNorm(protein_dim, CrossAttention(
                    protein_dim, protein_intermediate_dim, cross_attn_heads, dropout
                )),
                nn.Linear(protein_dim, molecule_dim),
                nn.Linear(molecule_dim, protein_dim),
                PreNorm(molecule_dim, CrossAttention(
                    molecule_dim, molecule_intermediate_dim, cross_attn_heads, dropout
                ))
            ]))

            
    def forward(self, molecule, protein):
        for i, (f_sl, g_ls, cross_attn_s, f_ls, g_sl, cross_attn_l) in enumerate(self.cross_attn_layers):
            
            cls_molecule = molecule[:, 0]
            x_molecule = molecule[:, 1:]
            
            cls_protein = protein[:, 0]
            x_protein = protein[:, 1:]

            # Cross attention for protein sequence
            cal_q = f_ls(cls_protein.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_molecule), dim=1)
            # add activation function
            cal_out = cal_q + cross_attn_l(cal_qkv)
            cal_out = F.gelu(g_sl(cal_out))
            protein = torch.cat((cal_out, x_protein), dim=1)

            # Cross attention for molecule sequence
            cal_q = f_sl(cls_molecule.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_protein), dim=1)
            # add activation function
            cal_out = cal_q + cross_attn_s(cal_qkv)
            cal_out = F.gelu(g_ls(cal_out))
            molecule = torch.cat((cal_out, x_molecule), dim=1)
            
        return molecule, protein
    
    
class AttentionalDTI(nn.Module):
    def __init__(self, 
                 cross_attention_layer, 
                 molecule_input_dim=128, protein_input_dim=256, hidden_dim=512, **kwargs):
        super().__init__()
        
        self.cross_attention_layer = cross_attention_layer
        
        self.molecule_mlp = nn.Sequential(
            nn.LayerNorm(molecule_input_dim),
            nn.Linear(molecule_input_dim, hidden_dim)
        )
        
        self.protein_mlp = nn.Sequential(
            nn.LayerNorm(protein_input_dim),
            nn.Linear(protein_input_dim, hidden_dim)
        )
        
        self.fc_out = nn.Linear(hidden_dim, 1)
        
    
    def forward(self, molecule_seq, protein_seq):
        molecule_out, protein_out = self.cross_attention_layer(molecule_seq, protein_seq)
        
        molecule_out = molecule_out[:, 0]
        protein_out = protein_out[:, 0]
        
        # cls token
        molecule_projected = self.molecule_mlp(molecule_out)
        protein_projected = self.protein_mlp(protein_out)
        
        out = self.fc_out(molecule_projected + protein_projected)
        
        return out

cross_attention_layer = CrossAttentionLayer()
attentional_dti = AttentionalDTI(cross_attention_layer, cross_attn_depth=8)

In [4]:
class DTI_prediction(pl.LightningModule):
    def __init__(self, attentional_dti):
        super().__init__()
        self.model = attentional_dti

        
    def forward(self, molecule_sequence, protein_sequence):
        return self.model(molecule_sequence, protein_sequence)
    
    
    def training_step(self, batch, batch_idx):
        molecule_sequence, protein_sequence, y = batch
        
        y_hat = self(molecule_sequence, protein_sequence).squeeze(-1)        
        loss = F.mse_loss(y_hat, y)
        
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_mae", mean_absolute_error(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        molecule_sequence, protein_sequence, y = batch
        
        y_hat = self(molecule_sequence, protein_sequence).squeeze(-1)        
        loss = F.mse_loss(y_hat, y)
        
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("valid_mae", mean_absolute_error(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)
    
    
    def test_step(self, batch, batch_idx):
        molecule_sequence, protein_sequence, y = batch
        
        y_hat = self(molecule_sequence, protein_sequence).squeeze(-1)        
        loss = F.mse_loss(y_hat, y)
        
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_mae", mean_absolute_error(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)
    
    
    def predict_step(self, batch, batch_idx):
        molecule_sequence, protein_sequence, y = batch
        
        y_hat = self(molecule_sequence, protein_sequence).squeeze(-1)        
        
        return y_hat

    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=5, dirpath='weights/Attentional_DTI_cross_attention_kiba_reduced_embedding', filename='attentional_dti-{epoch:03d}-{valid_loss:.4f}-{valid_mae:.4f}'),
]

model = DTI_prediction(attentional_dti)

trainer = pl.Trainer(max_epochs=500, gpus=1, enable_progress_bar=True, callbacks=callbacks)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_dataloader, valid_dataloader)

Missing logger folder: /home/ubuntu/Workspace/Attentional_DTI/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type           | Params
-----------------------------------------
0 | model | AttentionalDTI | 987 K 
-----------------------------------------
987 K     Trainable params
0         Non-trainable params
987 K     Total params
3.949     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [12]:
ckpt_fname = "attentional_dti-epoch=261-valid_loss=0.1651-valid_mae=0.2343.ckpt"

model = model.load_from_checkpoint("weights/Attentional_DTI_cross_attention_kiba/" + ckpt_fname, attentional_dti=attentional_dti)
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.16898369789123535, 'test_mae': 0.23121696710586548}
--------------------------------------------------------------------------------


[{'test_loss': 0.16898369789123535, 'test_mae': 0.23121696710586548}]

In [13]:
pred = trainer.predict(model, test_dataloader)

true_, pred_ = test_dataloader.dataset.data.Y, []

for i in pred:
    for j in i.tolist():
        pred_.append(j)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 4118it [00:00, ?it/s]

In [14]:
def get_cindex(Y, P):
    summ = 0
    pair = 0
    
    for i in tqdm(range(1, len(Y))):
        for j in range(0, i):
            if i is not j:
                if(Y[i] > Y[j]):
                    pair +=1
                    summ +=  1* (P[i] > P[j]) + 0.5 * (P[i] == P[j])
        
            
    if pair is not 0:
        return summ/pair
    else:
        return 0


def r_squared_error(y_obs,y_pred):
    y_obs = np.array(y_obs)
    y_pred = np.array(y_pred)
    y_obs_mean = [np.mean(y_obs) for y in y_obs]
    y_pred_mean = [np.mean(y_pred) for y in y_pred]

    mult = sum((y_pred - y_pred_mean) * (y_obs - y_obs_mean))
    mult = mult * mult

    y_obs_sq = sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean))
    y_pred_sq = sum((y_pred - y_pred_mean) * (y_pred - y_pred_mean) )

    return mult / float(y_obs_sq * y_pred_sq)


def get_k(y_obs,y_pred):
    y_obs = np.array(y_obs)
    y_pred = np.array(y_pred)

    return sum(y_obs*y_pred) / float(sum(y_pred*y_pred))


def squared_error_zero(y_obs,y_pred):
    k = get_k(y_obs,y_pred)

    y_obs = np.array(y_obs)
    y_pred = np.array(y_pred)
    y_obs_mean = [np.mean(y_obs) for y in y_obs]
    upp = sum((y_obs - (k*y_pred)) * (y_obs - (k* y_pred)))
    down= sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean))

    return 1 - (upp / float(down))


def get_rm2(ys_orig, ys_line):
    r2 = r_squared_error(ys_orig, ys_line)
    r02 = squared_error_zero(ys_orig, ys_line)

    return r2 * (1 - np.sqrt(np.absolute((r2*r2)-(r02*r02))))

get_cindex(true_, pred_)

100%|█████████████████████████████████████| 23530/23530 [27:19<00:00, 14.35it/s]


0.8762425728288034

In [15]:
import numpy as np

get_rm2(true_, pred_)

0.686873981628849

In [16]:
from sklearn.metrics import average_precision_score

true_label = [1 if t >= 12.1 else 0 for t in true_]

average_precision_score(true_label, pred_)

0.8208179169928685