In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from transformers import BertModel, BertTokenizer

import pickle 
import numpy as np
import pandas as pd
from tqdm import tqdm
from tdc.multi_pred import DTI
from sklearn.decomposition import TruncatedSVD


def molecule_encode(molecule_sequence):
        molecule_sequence = molecule_tokenizer(
            " ".join(molecule_sequence), 
            max_length=128, 
            truncation=True,
            return_tensors="pt"
        )
        
        return molecule_sequence
    
def protein_encode(protein_sequence):
    protein_sequence = protein_tokenizer(
        " ".join(protein_sequence), 
        max_length=1024, 
        truncation=True, 
        return_tensors="pt"
    )

    return protein_sequence


molecule_bert = BertModel.from_pretrained("weights/molecule_bert").to("cuda")
protein_bert = BertModel.from_pretrained("weights/protein_bert").to("cuda")

molecule_bert.eval()
protein_bert.eval()

molecule_tokenizer = molecule_tokenizer = BertTokenizer.from_pretrained("data/drug/molecule_tokenizer", model_max_length=128)
protein_tokenizer = BertTokenizer.from_pretrained("data/target/protein_tokenizer", do_lower_case=False)


Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at weights/protein_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# DAVIS

In [2]:
davis = DTI(name="davis")
davis.convert_to_log(form = 'binding')
davis_split = davis.get_split()

train_df = davis_split['train']
valid_df = davis_split['valid']
test_df = davis_split['test']

encoded_protein = protein_encode(train_df.Target.unique()[0]).to("cuda")
encoded_protein = protein_bert(**encoded_protein)

results = encoded_protein.last_hidden_state.squeeze(0).detach().cpu().numpy()

for target in tqdm(train_df.Target.unique()[1:]):
    encoded_protein = protein_encode(target).to("cuda")
    encoded_protein = protein_bert(**encoded_protein)
    result = encoded_protein.last_hidden_state.squeeze(0).detach().cpu().numpy()
    results = np.vstack((results, result))
    
svd = TruncatedSVD(n_components=256, random_state=42).fit(results)

Found local copy...
Loading...
Done!
To log space...
100%|█████████████████████████████████████████| 378/378 [02:06<00:00,  3.00it/s]


## Protein embedding - DAVIS

In [4]:
total_embeddings = {}

for mode in ["train", "valid", "test"]:
    if mode == "train":
        unique_df = train_df.loc[:, ["Target_ID", "Target"]].drop_duplicates().reset_index(drop=True)
    elif mode == "valid":
        unique_df = valid_df.loc[:, ["Target_ID", "Target"]].drop_duplicates().reset_index(drop=True)
    else:
        unique_df = test_df.loc[:, ["Target_ID", "Target"]].drop_duplicates().reset_index(drop=True)
    
    print("Mode:", mode)
    results = {}
    
    for i, line in tqdm(unique_df.iterrows(), total=len(unique_df)):
        encoded_protein = protein_encode(line['Target']).to("cuda")
        encoded_protein = protein_bert(**encoded_protein)

        protein_cls = encoded_protein.pooler_output.detach().cpu().numpy()
        protein_emb = encoded_protein.last_hidden_state[0].detach().cpu().numpy()
        reduced_emb = svd.transform(protein_emb)

        embeddings = [protein_cls, protein_emb, reduced_emb]
        results[line["Target_ID"]] = embeddings
    
    total_embeddings[mode] = results
    
with open("data/target/davis_target_embeddings.pkl", "wb") as f:
    pickle.dump(total_embeddings, f)

Mode: train


100%|█████████████████████████████████████████| 379/379 [00:14<00:00, 25.61it/s]


Mode: valid


100%|█████████████████████████████████████████| 378/378 [00:14<00:00, 25.94it/s]


Mode: test


100%|█████████████████████████████████████████| 379/379 [00:14<00:00, 26.20it/s]


## Molecule embedding - DAVIS

In [5]:
total_embeddings = {}

for mode in ["train", "valid", "test"]:
    if mode == "train":
        unique_df = train_df.loc[:, ["Drug_ID", "Drug"]].drop_duplicates().reset_index(drop=True)
    elif mode == "valid":
        unique_df = valid_df.loc[:, ["Drug_ID", "Drug"]].drop_duplicates().reset_index(drop=True)
    else:
        unique_df = test_df.loc[:, ["Drug_ID", "Drug"]].drop_duplicates().reset_index(drop=True)
    
    print("Mode:", mode)
    results = {}
    
    for i, line in tqdm(unique_df.iterrows(), total=len(unique_df)):
        encoded_molecule = molecule_encode(line['Drug']).to("cuda")
        encoded_molecule = molecule_bert(**encoded_molecule)

        molecule_cls = encoded_molecule.pooler_output.detach().cpu().numpy()
        moelcule_emb = encoded_molecule.last_hidden_state[0].detach().cpu().numpy()

        embeddings = [molecule_cls, moelcule_emb]
        results[line["Drug_ID"]] = embeddings
    
    total_embeddings[mode] = results
    
with open("data/drug/davis_drug_embeddings.pkl", "wb") as f:
    pickle.dump(total_embeddings, f)

Mode: train


100%|██████████████████████████████████████████| 68/68 [00:00<00:00, 117.41it/s]


Mode: valid


100%|██████████████████████████████████████████| 68/68 [00:00<00:00, 117.60it/s]


Mode: test


100%|██████████████████████████████████████████| 68/68 [00:00<00:00, 118.13it/s]


# KIBA

In [10]:
kiba = DTI(name="kiba")
kiba_split = kiba.get_split()

train_df = kiba_split['train']
valid_df = kiba_split['valid']
test_df = kiba_split['test']

encoded_protein = protein_encode(train_df.Target.unique()[0]).to("cuda")
encoded_protein = protein_bert(**encoded_protein)

results = encoded_protein.last_hidden_state.squeeze(0).detach().cpu().numpy()

for target in tqdm(train_df.Target.unique()[1:]):
    encoded_protein = protein_encode(target).to("cuda")
    encoded_protein = protein_bert(**encoded_protein)
    result = encoded_protein.last_hidden_state.squeeze(0).detach().cpu().numpy()
    results = np.vstack((results, result))

svd = TruncatedSVD(n_components=256, random_state=42).fit(results)

Found local copy...
Loading...
Done!
100%|█████████████████████████████████████████| 228/228 [00:31<00:00,  7.26it/s]


## Protein embedding - KIBA

In [11]:
total_embeddings = {}

for mode in ["train", "valid", "test"]:
    if mode == "train":
        unique_df = train_df.loc[:, ["Target_ID", "Target"]].drop_duplicates().reset_index(drop=True)
    elif mode == "valid":
        unique_df = valid_df.loc[:, ["Target_ID", "Target"]].drop_duplicates().reset_index(drop=True)
    else:
        unique_df = test_df.loc[:, ["Target_ID", "Target"]].drop_duplicates().reset_index(drop=True)
    
    print("Mode:", mode)
    results = {}
    
    for i, line in tqdm(unique_df.iterrows(), total=len(unique_df)):
        encoded_protein = protein_encode(line['Target']).to("cuda")
        encoded_protein = protein_bert(**encoded_protein)

        protein_cls = encoded_protein.pooler_output.detach().cpu().numpy()
        protein_emb = encoded_protein.last_hidden_state[0].detach().cpu().numpy()
        reduced_emb = svd.transform(protein_emb)

        embeddings = [protein_cls, protein_emb, reduced_emb]
        results[line["Target_ID"]] = embeddings
    
    total_embeddings[mode] = results
    
with open("data/target/kiba_target_embeddings.pkl", "wb") as f:
    pickle.dump(total_embeddings, f)

Mode: train


100%|█████████████████████████████████████████| 229/229 [00:09<00:00, 25.03it/s]


Mode: valid


100%|█████████████████████████████████████████| 226/226 [00:07<00:00, 29.13it/s]


Mode: test


100%|█████████████████████████████████████████| 228/228 [00:08<00:00, 27.29it/s]


## Molecule embedding - KIBA

In [12]:
total_embeddings = {}

for mode in ["train", "valid", "test"]:
    if mode == "train":
        unique_df = train_df.loc[:, ["Drug_ID", "Drug"]].drop_duplicates().reset_index(drop=True)
    elif mode == "valid":
        unique_df = valid_df.loc[:, ["Drug_ID", "Drug"]].drop_duplicates().reset_index(drop=True)
    else:
        unique_df = test_df.loc[:, ["Drug_ID", "Drug"]].drop_duplicates().reset_index(drop=True)
    
    print("Mode:", mode)
    results = {}
    
    for i, line in tqdm(unique_df.iterrows(), total=len(unique_df)):
        encoded_molecule = molecule_encode(line['Drug']).to("cuda")
        encoded_molecule = molecule_bert(**encoded_molecule)

        molecule_cls = encoded_molecule.pooler_output.detach().cpu().numpy()
        moelcule_emb = encoded_molecule.last_hidden_state[0].detach().cpu().numpy()

        embeddings = [molecule_cls, moelcule_emb]
        results[line["Drug_ID"]] = embeddings
    
    total_embeddings[mode] = results
    
with open("data/drug/kiba_drug_embeddings.pkl", "wb") as f:
    pickle.dump(total_embeddings, f)

Mode: train


100%|██████████████████████████████████████| 2068/2068 [00:16<00:00, 125.46it/s]


Mode: valid


100%|██████████████████████████████████████| 1850/1850 [00:15<00:00, 119.12it/s]


Mode: test


100%|██████████████████████████████████████| 2021/2021 [00:16<00:00, 120.53it/s]
