In [1]:
import pandas as pd
import torch
from transformers import T5Tokenizer, T5Model
import requests
import pickle

In [2]:
df1 = pd.read_csv('new_chembl_inhibit_drug_target.csv')
df2 = pd.read_csv('cancer_ppi_combined.csv')

In [3]:
protein_list1 = df1['target_accession_number']

protein_list2 = df2['node1_uniprot_id']

protein_list3 = df2['node2_uniprot_id']

protein_list = pd.concat([protein_list1, protein_list2, protein_list3], ignore_index=True)

protein_list = protein_list.drop_duplicates()

print(len(protein_list))

5914


In [4]:
# Load the Prot5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
print("Tokenizer loaded")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Tokenizer loaded


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = T5Model.from_pretrained("Rostlab/prot_t5_xl_uniref50").half().to(device)
print("Model loaded")

Model loaded


In [6]:
protein_list_1 = protein_list[:2000]
protein_list_2 = protein_list[2000:4000]
protein_list_3 = protein_list[4000:]

In [7]:
# Function to fetch protein sequence from UniProt
def fetch_uniprot_sequence(uniprot_id):
    url = f"https://www.uniprot.org/uniprot/{uniprot_id}.fasta"
    response = requests.get(url)
    if response.status_code == 200:
        sequence = ''.join(response.text.splitlines()[1:])  # Skip the FASTA header
        return sequence
    else:
        print(f"Error fetching sequence for {uniprot_id}")
        return None

# Function to generate embeddings for a list of protein sequences
def get_protein_embeddings(uniprot_ids):
    embeddings_dict = {}  # Dictionary to cache embeddings
    num = 0

    for uniprot_id in uniprot_ids:
        num += 1
        if num%20 == 0:
            print(f"Processing uniprot id : {uniprot_id} number {num}")

        try:
            sequence = fetch_uniprot_sequence(uniprot_id)  # Fetch the protein sequence
            if sequence:
                # Tokenize the input sequence
                inputs = tokenizer(sequence, return_tensors="pt", padding=True).to(device)

                # Add decoder_input_ids: initialize with the pad token id
                decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]]).to(device)

                # Forward pass with encoder input and decoder input
                with torch.no_grad():
                    outputs = model(input_ids=inputs['input_ids'], decoder_input_ids=decoder_input_ids)

                # Extract embeddings (e.g., from encoder output or mean pooling)
                embedding = outputs.last_hidden_state.mean(dim=1)  # Mean pooling for simplicity
                embeddings_dict[uniprot_id] = embedding.squeeze(0).cpu().numpy()  # Cache the embedding
            else:
                embeddings_dict[uniprot_id] = [0] * 1024  # Default zero embedding for missing sequences
        except:
            print(f"Error processing uniprot id : {uniprot_id} number {num}")

    return embeddings_dict


# Generate the embeddings for the unique UniProt IDs
embeddings_dict_1 = get_protein_embeddings(protein_list_1)
print("Embeddings generated for protein_list_1")

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Processing uniprot id : P14780 number 20
Processing uniprot id : O75385 number 40
Processing uniprot id : Q9UNQ0 number 60
Error processing uniprot id : P09958 number 68
Processing uniprot id : Q92889 number 80
Processing uniprot id : P04049 number 100
Error processing uniprot id : Q9UEW8 number 113
Error processing uniprot id : Q86V86 number 117
Processing uniprot id : P49674 number 120
Error processing uniprot id : P07237 number 121
Error processing uniprot id : P49336 number 130
Processing uniprot id : Q05397 number 140
Processing uniprot id : P22607 number 160
Processing uniprot id : Q96GG9 number 180
Processing uniprot id : Q13315 number 200
Processing uniprot id : P11274 number 220
Processing uniprot id : Q13535 number 240
Processing uniprot id : O00141 number 260
Processing uniprot id : P53667 number 280
Processing uniprot id : P00403 number 300
Processing uniprot id : O00571 number 320
Processing uniprot id : P17812 number 340
Processing uniprot id : O15245 number 360
Processin

In [8]:
embeddings_dict_2 = get_protein_embeddings(protein_list_2)
print("Embeddings generated for protein_list_2")

Processing uniprot id : P84101 number 20
Processing uniprot id : Q9UJ83 number 40
Processing uniprot id : Q96M29 number 60
Processing uniprot id : O00629 number 80
Processing uniprot id : Q12873 number 100
Error processing uniprot id : Q12873 number 100
Error processing uniprot id : Q9UGK3 number 117
Processing uniprot id : Q07065 number 120
Processing uniprot id : O15123 number 140
Processing uniprot id : Q14493 number 160
Processing uniprot id : Q9UPV9 number 180
Processing uniprot id : P23528 number 200
Processing uniprot id : Q9Y237 number 220
Processing uniprot id : P25929 number 240
Processing uniprot id : O94842 number 260
Processing uniprot id : P55010 number 280
Processing uniprot id : Q14498 number 300
Processing uniprot id : Q9P2N2 number 320
Processing uniprot id : Q9Y4X5 number 340
Processing uniprot id : Q99717 number 360
Processing uniprot id : Q12982 number 380
Processing uniprot id : P16435 number 400
Processing uniprot id : Q92692 number 420
Processing uniprot id : O9

In [9]:
embeddings_dict_3 = get_protein_embeddings(protein_list_3)
print("Embeddings generated for protein_list_3")

Processing uniprot id : Q14643 number 20
Processing uniprot id : O43246 number 40
Processing uniprot id : Q9BZD6 number 60
Processing uniprot id : Q53GZ6 number 80
Processing uniprot id : Q9H201 number 100
Processing uniprot id : Q7Z739 number 120
Processing uniprot id : Q9NWU5 number 140
Processing uniprot id : Q7Z7H5 number 160
Error processing uniprot id : Q9HC36 number 174
Processing uniprot id : P0DOY2 number 180
Processing uniprot id : Q6WCQ1 number 200
Processing uniprot id : Q96SB8 number 220
Processing uniprot id : Q9UBI6 number 240
Processing uniprot id : Q9UK22 number 260
Processing uniprot id : Q5VZ89 number 280
Error processing uniprot id : Q9HCC0 number 288
Processing uniprot id : Q9BZF9 number 300
Error processing uniprot id : P09012 number 303
Processing uniprot id : O15400 number 320
Processing uniprot id : O95208 number 340
Processing uniprot id : P42167 number 360
Processing uniprot id : Q13277 number 380
Processing uniprot id : Q53GL0 number 400
Processing uniprot i

In [10]:
embeddings_dict = {**embeddings_dict_1, **embeddings_dict_2, **embeddings_dict_3}

In [13]:
print(len(embeddings_dict))

5914


In [12]:
while True:
    missing_ids = [m_id for m_id in protein_list if m_id not in embeddings_dict]
    if len(missing_ids) == 0:
        break
    print("Iteration ", len(missing_ids))
    new_dict = get_protein_embeddings(missing_ids)
    embeddings_dict.update(new_dict)

Iteration  20
Processing uniprot id : B4DGN8 number 20


In [14]:
with open("embeddings_dict.pkl", "wb") as f:
    pickle.dump(embeddings_dict, f)