In [1]:
import glob
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from tdc.multi_pred import DTI
from transformers import AutoTokenizer, AutoModel

mol_tokenizer = AutoTokenizer.from_pretrained("jonghyunlee/DrugLikeMoleculeBERT")
mol_encoder = AutoModel.from_pretrained("jonghyunlee/DrugLikeMoleculeBERT")
mol_encoder.to("cuda")
mol_encoder.eval()

prot_tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
prot_encoder = AutoModel.from_pretrained("Rostlab/prot_bert")
prot_encoder.to("cuda")
prot_encoder.eval()

print()

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).





In [2]:
def preproc_kiba():
    kiba = DTI(name="kiba").get_split()

    def gen_label(sub_df):
        sub_df.loc[:, "Label"] = sub_df.Y.map(lambda x: 1 if x >= 12.1 else 0)
        sub_df = sub_df.loc[:, ["Drug", "Target", "Label", "Y"]]
        sub_df.columns = ["SMILES", "Target Sequence", "Label", "Y"]
        return sub_df
    
    kiba_train = gen_label(kiba['train'])
    kiba_valid = gen_label(kiba['train'])
    kiba_test = gen_label(kiba['train'])

    kiba_train.to_csv("data/mol_trans/KIBA_train.csv", index=False)
    kiba_valid.to_csv("data/mol_trans/KIBA_valid.csv", index=False)
    kiba_test.to_csv("data/mol_trans/KIBA_test.csv", index=False)
    
# preproc_kiba()

In [3]:
def get_unique(fname_filter):
    flist = glob.glob(fname_filter)
    df = pd.DataFrame()
    
    for f in flist:
        df_ = pd.read_csv(f)
        df = df.append(df_)
    
    mols = df["SMILES"].drop_duplicates().reset_index(drop=True)
    prots = df["Target Sequence"].drop_duplicates().reset_index(drop=True)
    
    return mols, prots

davis = "data/mol_trans/DAVIS*.csv"
davis_mols, davis_prots = get_unique(davis)

binding = "data/mol_trans/BindingDB*.csv"
binding_mols, binding_prots = get_unique(binding)

# kiba = "data/mol_trans/KIBA*.csv"
# kiba_mols, kiba_prots = get_unique(kiba)

biosnap = "data/mol_trans/BIOSNAP*.csv"
biosnap_mols, biosnap_prots = get_unique(biosnap)

In [21]:
prots = np.array([])

for prot in [davis_prots, binding_prots, biosnap_prots]:
    prots = np.append(prots, prot)
    
prots = np.unique(prots)
prots = pd.DataFrame(prots, columns=['FASTA'])
prots.iloc[-1, 0] = prots.iloc[-1, 0].upper()
prots.to_csv("data/mol_trans/protein_sequences.csv", index=False)

In [12]:
def get_embeddings(df, fname, outfname, mode="mol"):
    cls_feature_dict = {}
    full_feature_dict = {}
    
    print(fname)
    for seq in tqdm(df, total=len(df)):
        if mode == "mol":
            name = seq[:10]
            X = mol_tokenizer.encode_plus(" ".join(seq) + " [PAD]" * (128-len(seq)), 
                                          return_tensors="pt", max_length=128, truncation=True)
            output = mol_encoder(**X.to("cuda"))
        elif mode == "prot":
            name = seq[:30]
            X = prot_tokenizer.encode_plus(" ".join(seq) + " [PAD]" * (768-len(seq)), 
                                           return_tensors="pt", max_length=768, truncation=True)
            output = prot_encoder(**X.to("cuda"))
        
        cls_feature_dict[name] = output[1].detach().to("cpu")
        full_feature_dict[name] = output[0].detach().to("cpu")
    
    with open("data/mol_trans/" + fname + "_" + outfname + "_cls.pkl", "wb") as f:
        pickle.dump(cls_feature_dict, f)
    
    with open("data/mol_trans/" + fname + "_" + outfname + "_full.pkl", "wb") as f:
        pickle.dump(full_feature_dict, f)
        
        
# get_embeddings(davis_mols, "davis_mols", "mol")
# get_embeddings(kiba_mols, "kiba_mols", "mol")
# get_embeddings(binding_mols, "binding_mols", "mol")
# get_embeddings(biosnap_mols, "biosnap_mols", "mol")

get_embeddings(davis_prots, "davis_prots", outfname="768", mode="prot")
# get_embeddings(kiba_prots, "kiba_prots", outfname="512", mode="prot")
get_embeddings(binding_prots, "binding_prots", outfname="768", mode="prot")
get_embeddings(biosnap_prots, "biosnap_prots", outfname="768", mode="prot")

davis_prots


100%|█████████████████████████████████████████| 379/379 [00:22<00:00, 17.14it/s]


binding_prots


100%|███████████████████████████████████████| 1254/1254 [01:14<00:00, 16.80it/s]


biosnap_prots


100%|███████████████████████████████████████| 2180/2180 [02:11<00:00, 16.58it/s]


In [13]:
def generate_merged_dict(flist, fname):
    merged_dict = {}
    
    for file_name in flist:
        with open(file_name, "rb") as f:
            sub_dict = pickle.load(f)
            
        merged_dict.update(sub_dict)
        
    with open("data/mol_trans/" + fname + ".pkl", "wb") as f:
        pickle.dump(merged_dict, f)

# mols_cls_list = glob.glob("data/mol_trans/*_mols_cls.pkl")
# mols_full_list = glob.glob("data/mol_trans/*_mols_full.pkl")
prots_cls_list = glob.glob("data/mol_trans/*_prots_768_cls.pkl")
prots_full_list = glob.glob("data/mol_trans/*_prots_768_full.pkl")

# generate_merged_dict(mols_cls_list, "mols_cls")
# generate_merged_dict(mols_full_list, "mols_full")
generate_merged_dict(prots_cls_list, "prots_768_cls")
generate_merged_dict(prots_full_list, "prots_768_full")

In [5]:
def merge_dataset(flist, mode):
    df = pd.DataFrame()
    for i, f in enumerate(flist):
        df_ = pd.read_csv(f)
        df_.loc[:, "Source"] = i
        
        df = df.append(df_)
    df.to_csv(f"data/mol_trans/{mode}_dataset.csv", index=False)
    
merge_dataset(glob.glob("data/mol_trans/*train.csv")[:3], "train")
merge_dataset(glob.glob("data/mol_trans/*valid.csv")[:3], "valid")
merge_dataset(glob.glob("data/mol_trans/*test.csv")[:3], "test")

  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)
  df = df.append(df_)


In [4]:
import glob
import pandas as pd

flist = glob.glob("data/mol_trans/*train.csv")[:3]
flist

['data/mol_trans/DAVIS_train.csv',
 'data/mol_trans/BindingDB_train.csv',
 'data/mol_trans/BIOSNAP_train.csv']