In [1]:
import os
import re
import pickle
import torch
import tqdm
import pandas as pd
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer

In [2]:
## directory containing the preprocessed input table ("PDBbind_data.tsv" is read from here below)
input_dir = "data_preprocessed"

In [3]:
## directory that receives one "<PDB id>.npy" embedding file per entry
## os.makedirs creates both levels ("data_embeddings" and "pdbbind") in one call;
## exist_ok=True makes the cell safe to re-run and avoids the check-then-create race
output_dir = os.path.join("data_embeddings", "pdbbind")
os.makedirs(output_dir, exist_ok=True)

# 1. Data Load

In [4]:
## load the tab-separated PDBbind table from the input directory
pdbbind_tsv_path = os.path.join(input_dir, "PDBbind_data.tsv")
prots_df = pd.read_csv(pdbbind_tsv_path, sep="\t")

In [5]:
## preview the loaded table (columns shown below: PDB, Sequence, BS)
prots_df

Unnamed: 0,PDB,Sequence,BS
0,4x14,VEVLEVKTGVDSITEVECFLTPEMGDPDEHLRGFSKSISISDTFES...,"361,362,364,367,368,369,370,371,372,373,374,50..."
1,4ruu,TRDQNGTWEMESNENFEGYMKALDIDFATRKIAVRLTQTLVIDQDG...,"1,2,3,4,5,7,9,12,14,15,16,17,18,19,20,21,22,23..."
2,5hx8,IVSEKKPATEVDPTHFEKRFLKRIRDLGEGHFGKVELCRYDPEGDN...,"23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,4..."
3,2ymt,SAPIPDLKVFEREGVQLNLSFIRPPENPALLLITITATNFSEGDVT...,"16,18,32,34,47,48,49,50,51,52,53,54,55,56,57,5..."
4,4km2,MVGLIWAQATSGVIGRGGDIPWRLPEDQAHFREITMGHTIVMGRRT...,"2,3,4,5,6,7,8,13,14,15,17,18,19,20,21,22,23,24..."
...,...,...,...
9973,5ew9,QWALEDFEIGRPLGKGKFGNVYLAREKQSKFILALKVLFKAQLEKA...,"10,11,12,13,14,15,16,17,18,19,20,21,22,23,30,3..."
9974,4f7l,DKMDYDFKVKLSSERERVEDLFEYEGCKVGRGTYGHVYKAKRKDGK...,"26,27,28,29,30,31,32,33,34,35,36,37,38,39,49,5..."
9975,4elh,MIVSFMVAMDENRVIGKDNNLPWRLPSELQYVKKTTMGHPLIMGRK...,"3,4,5,6,7,8,9,12,13,14,15,16,17,18,19,20,21,22..."
9976,4o3a,GANKTVVVTTILESPYVMMKKNHEMLEGNERYEGYCVDLAAEIAKH...,"9,10,11,12,13,14,15,35,55,58,59,60,61,62,63,70..."


# 2. Tokenizer and Pretrained Model

In [6]:
## tokenizer matching the "Rostlab/prot_t5_xl_uniref50" checkpoint; case is left untouched
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50",
    do_lower_case=False,
)

In [7]:
## Load only the encoder half of the T5 checkpoint.
## The class attribute below lists key patterns ("decoder.*", "lm_head.*") to ignore
## when they show up as unexpected keys during loading -- presumably the checkpoint
## also ships decoder / LM-head weights that T5EncoderModel has no use for (see link).
## NOTE(review): this sets a private transformers attribute and must run before from_pretrained.
## https://stackoverflow.com/questions/71788825/using-the-encoder-part-only-from-t5-model
T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*", "lm_head.*"]
prots_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")

In [8]:
## Inference is pinned to the CPU. To use a GPU when one is available, swap in:
##   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [9]:
## move the encoder to the selected device and switch it to evaluation mode
prots_model = prots_model.to(device).eval()

# 3. Feature Generation

In [10]:
## count the entries whose "Sequence" field holds more than one comma-separated chain
multi_chain_flags = [
    len(chain_field.split(",")) > 1
    for chain_field in prots_df["Sequence"].values
]
cnt = sum(multi_chain_flags)

print(cnt)

4876


In [11]:
## Compute one embedding matrix per PDB entry and store it as "<PDB id>.npy".
##
## The full pass takes hours, so the loop is resumable: an entry whose output file
## already exists is skipped. Set overwrite_existing = True to force recomputation
## (e.g. after changing the model or the preprocessing).
## NOTE(review): with skipping enabled, a PDB id that appears in more than one row keeps
## the embedding of its first row; the previous version kept the last one.
overwrite_existing = False

## iterate over the column values positionally so the loop does not depend on the
## DataFrame having a default 0..n-1 index
pdb_ids = prots_df["PDB"].values
chain_fields = prots_df["Sequence"].values

for pdbid, chain_field in tqdm.tqdm(zip(pdb_ids, chain_fields), total=len(prots_df)):
    ## resume support: skip entries that were already written by an earlier run
    out_path = os.path.join(output_dir, f"{pdbid}.npy")
    if os.path.exists(out_path) and not overwrite_existing:
        continue

    ## seqs: a single entry may hold several chains separated by ","
    seqs = chain_field.split(",")

    ## tokenization: U/Z/O/B are replaced by X and residues are separated by spaces;
    ## chains of one entry are padded to a common length
    batch = tokenizer.batch_encode_plus(
        [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in seqs],
        add_special_tokens=True,
        padding=True
    )

    ## embedding: first model output, moved back to the CPU as a NumPy array
    with torch.no_grad():
        embeddings = prots_model(
            input_ids=torch.tensor(batch["input_ids"], device=device),
            attention_mask=torch.tensor(batch["attention_mask"], device=device),
        )[0].cpu().numpy()

    ## postprocess for the special token: keep only the first len(seq) positions of
    ## each chain (drops whatever the tokenizer appended plus any padding), then
    ## stack all chains of the entry into one matrix
    embedding = np.vstack([
        emb[:len(seq)] for emb, seq in zip(embeddings, seqs)
    ])

    ## store: write to a temporary file first and rename it afterwards, so an
    ## interrupted run never leaves a truncated file under the final name
    ## (the temporary name ends in ".npy" so np.save does not append a suffix)
    tmp_path = os.path.join(output_dir, f"{pdbid}.tmp.npy")
    np.save(tmp_path, embedding)
    os.replace(tmp_path, out_path)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9978/9978 [8:12:57<00:00,  2.96s/it]
