#### STEPS

- We extracted all unique protein identifiers from BioKG in a previous notebook.
- We split them into 2 parts due to UniProt API restrictions.
- We queried UniProt to get the amino-acid sequence for each id.
- We remove the protein sequences that are TOO LONG to embed (the 5 longest of each part, which caused memory errors).
- We split the remaining sequences into batches (chunks) to extract embeddings efficiently.
- We then, for each batch:
   - load it
   - generate an embedding for every protein
   - store the embeddings

In [1]:
import pandas as pd
import numpy as np
import torch
import tqdm
import json
import dill

#### ID mapping creation

In [2]:
# The set of unique ids from BioKG metadata, read from a CSV whose first
# column is used as the DataFrame index.
# NOTE(review): this frame is not referenced again in the visible cells.
unique_biokg_prot_ids = pd.read_csv('../data/biokg/unique_proteins.csv', 
                                    index_col=0)

In [3]:
# First part of the UniProt query results (tab-separated file).
biokg_uniprot_ids_sequences_pt_0 = pd.read_table(
    '../data/biokg/uniprot/biokg_uniprot_id_sequences.tsv'
)



In [4]:
# Second part of the UniProt query results (tab-separated file).
biokg_uniprot_ids_sequences_pt_1 = pd.read_table(
    '../data/biokg/uniprot/biokg_uniprot_id_sequences_part2.tsv'
)

In [5]:
# Number of rows in each of the two UniProt result parts.
len(biokg_uniprot_ids_sequences_pt_0), len(biokg_uniprot_ids_sequences_pt_1)

(61420, 61423)

In [6]:
# Order both parts from the largest to the smallest 'Length' value so that
# the longest sequences end up in the first rows of each frame.
biokg_uniprot_ids_sequences_pt_0 = biokg_uniprot_ids_sequences_pt_0.sort_values(
    by='Length', ascending=False
)
biokg_uniprot_ids_sequences_pt_1 = biokg_uniprot_ids_sequences_pt_1.sort_values(
    by='Length', ascending=False
)

In [7]:
# Collect the first 5 rows of each part into one frame for inspection.
# These are the 5 longest sequences of each part only if both frames have
# already been sorted by descending 'Length'.
edge_cases = pd.concat([
    biokg_uniprot_ids_sequences_pt_0.head(5),
    biokg_uniprot_ids_sequences_pt_1.head(5),
])

In [8]:
# Preview the first three rows of the edge-case frame.
edge_cases.head(3)

Unnamed: 0,From,Entry,Reviewed,Entry Name,Protein names,Gene Names,Organism,Length,Sequence
34528,A2ASS6,A2ASS6,reviewed,TITIN_MOUSE,Titin (EC 2.7.11.1) (Connectin),Ttn,Mus musculus (Mouse),35213,MTTQAPMFTQPLQSVVVLEGSTATFEAHVSGSPVPEVSWFRDGQVI...
4109,G4SLH0,G4SLH0,reviewed,TTN1_CAEEL,Titin homolog (EC 2.7.11.1),ttn-1 W06H8.8,Caenorhabditis elegans,18562,MEGNEKKGGGLPPTQQRHLNIDTTVGGSISQPVSPSMSYSTDRETV...
59931,Q9I7U4,Q9I7U4,reviewed,TITIN_DROME,Titin (D-Titin) (Kettin),sls titin CG1915,Drosophila melanogaster (Fruit fly),18141,MQRQNPNPYQQQNQQHQQVQQFSSQEYSHSSQEQHQEQRISRTEQH...


In [9]:
# Drop the first 5 rows of each part: the sequences that are too long
# (memory error). Both frames are expected to be sorted by descending 'Length'.
# Caution: each name is reassigned from itself, so re-running this cell
# drops another 5 rows from each part.
biokg_uniprot_ids_sequences_pt_0 = biokg_uniprot_ids_sequences_pt_0.iloc[5:]
biokg_uniprot_ids_sequences_pt_1 = biokg_uniprot_ids_sequences_pt_1.iloc[5:]

In [10]:
# MERGE THEM BACK AFTER REMOVING EDGE CASES
# ignore_index=True gives the combined frame a fresh 0..n-1 index.
biokg_uniprot_id_seq_full = pd.concat([biokg_uniprot_ids_sequences_pt_0, biokg_uniprot_ids_sequences_pt_1], ignore_index=True)

In [11]:
# Keep only the three columns used from here on: 'From', 'Sequence' and 'Length'.
biokg_uniprot_id_seq_full = biokg_uniprot_id_seq_full[['From', 'Sequence', 'Length']]

In [12]:
## Split the full table into consecutive chunks of at most n rows for ProtTrans.
n = 1000
chunk_starts = range(0, biokg_uniprot_id_seq_full.shape[0], n)
chunks = [
    biokg_uniprot_id_seq_full.iloc[start:start + n].copy()
    for start in chunk_starts
]


In [13]:
# Number of chunks produced by the split.
len(chunks)

123

In [14]:
# Write every chunk to its own CSV file so the embedding step can reload
# them one at a time.
# The DataFrame index is written as an extra leading column (pandas default);
# it is kept so the on-disk layout of the chunk files stays unchanged.
for num, chunk in enumerate(chunks):
    chunk.to_csv(f'../data/biokg/uniprot/chunks/chunk_{num}.csv')

In [15]:
# NOTE(review): this displays the entire list of chunk DataFrames;
# consider previewing a single chunk (e.g. chunks[0].head()) instead.
chunks

[       From                                           Sequence  Length
 0    Q6ZWR6  MATSRASSRSHRDITNVMQRLQDEQEIVQKRTFTKWINSHLAKRKP...    8799
 1    Q8NF91  MATSRGASRCPRDIANVMQRLQDEQEIVQKRTFTKWINSHLAKRKP...    8797
 2    Q9N4M4  MSSSPPARPCCVCFRFRPHEDEKAQKNTFTRWINFHLEEHSSSGRI...    8545
 3    Q7Z5P9  MKLILWYLVVALWCFFKDVEALLYRQKSDGKIAASRSGGFSYGSSS...    8384
 4    O01761  MASRRQKQFDRKYSSYRKFTATEDVNYSTHSSRSSYRSESLTSRTD...    8081
 ..      ...                                                ...     ...
 995  Q8C3J5  MAPWRKTDKERHGVAIYNFQGSEAQHLTLQIGDVVRIQETCGDWYR...    1828
 996  E9Q7E2  MANSTGKAPPDERRKGLAFLDELRQFHHSRGSPFKKIPAVGGKELD...    1828
 997  P9WQE3  MTSLAERAAQLSPNARAALARELVRAGTTFPTDICEPVAVVGIGCR...    1827
 998  P14410  MARKKFSGLEISLIVLFVIVTIIAIALIVVLATKTPAVDEISDSTS...    1827
 999  Q8JHV6  MLLRLELSALLLLLIAAPVRLQDECVGNSCYPNLGDLMVGRAAQLA...    1827
 
 [1000 rows x 3 columns],
         From                                           Sequence  Length
 1000  E9PZM4  MMRNKDKSQEEDSSLHSNAS


#### Embedding generation

In [16]:
import numpy as np
import bio_embeddings
from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, prottrans_t5_embedder, esm_embedder

In [17]:
# If this cell executes for the first time, expect a delay
# (presumably while the model files are fetched — TODO confirm).
prot_trans_embedder = ProtTransBertBFDEmbedder()

In [23]:
# Get the aggregated protein representation

def get_averaged_protein_repr(amino_repr, embedder):
    """Embed one sequence and average the embedding over its first axis.

    Args:
        amino_repr: The value handed to ``embedder.embed``. The callers in
            this notebook pass the amino-acid string of a single protein.
        embedder: Any object exposing ``embed(sequence)``; this notebook
            uses a ``ProtTransBertBFDEmbedder`` instance.

    Returns:
        A ``torch.Tensor`` holding the mean of the embedding along dim 0.
        Assuming the embedder yields one row per amino acid, e.g. a
        (406, 1024) matrix for a 406-residue protein as described in this
        notebook, the result is one 1024-dimensional vector for the whole
        protein (shapes depend on the embedder; TODO confirm).
    """
    per_residue = torch.Tensor(embedder.embed(amino_repr))
    return per_residue.mean(dim=0)

In [19]:
# Get the aggregated protein representation

def get_matrix_protein_repr(amino_repr, embedder):
    """Embed one sequence and return the full, un-aggregated embedding.

    Args:
        amino_repr: The value handed to ``embedder.embed``; presumably the
            amino-acid string of a single protein (TODO confirm, no caller
            is visible in this notebook).
        embedder: Any object exposing ``embed(sequence)``; this notebook
            uses a ``ProtTransBertBFDEmbedder`` instance.

    Returns:
        The output of ``embedder.embed`` converted to a ``torch.Tensor``,
        with no averaging applied. Assuming one row per amino acid, this is
        e.g. a (406, 1024) matrix for a 406-residue protein (shapes depend
        on the embedder; TODO confirm).
    """
    return torch.Tensor(embedder.embed(amino_repr))

In [20]:
# # TEST WITH 1 CHUNK

# chunk = pd.read_csv(f'../data/biokg/uniprot/chunks/chunk_59.csv')    
    
# print(len(chunk))
# chunk['embedding'] = None

# # Get emb for every batch
# for row in chunk.itertuples():
#     chunk.at[row.Index, 'embedding'] = torch.Tensor(get_protein_repr(row[3], prot_trans_embedder))

# df = chunk.set_index('From')

# # save every batch as a .pt
# df = df[['embedding']]

# # Make it a dict
# protein_emb = df.to_dict()


In [21]:
# Number of chunk files to process, taken from the in-memory `chunks` list.
num_chunks = len(chunks)

In [24]:
# Embed every chunk and store one .pt file per chunk.
for i in range(num_chunks):

    # Reload the chunk that was written to disk earlier in this notebook.
    chunk = pd.read_csv(f'../data/biokg/uniprot/chunks/chunk_{i}.csv')

    # One averaged embedding per protein, keyed by the 'From' column.
    # Columns are addressed by name rather than by tuple position, so the
    # lookup does not depend on the extra index column stored in the CSV.
    # get_averaged_protein_repr already returns a torch.Tensor, so it is
    # stored as-is without being wrapped again.
    embedding_by_id = {
        row.From: get_averaged_protein_repr(row.Sequence, prot_trans_embedder)
        for row in chunk.itertuples()
    }

    # Same nested layout that DataFrame[['embedding']].to_dict() produced:
    # {'embedding': {protein_id: tensor}}; kept so existing consumers of
    # the saved files see an unchanged structure.
    protein_emb = {'embedding': embedding_by_id}

    print(f"Saving protein embeddings......{i}")
    # Save it
    filename = f'../data/biokg/uniprot/embeddings/{i}_protein_batch.pt'

    with open(filename, 'wb') as f:  # Overwrites any existing file.
        torch.save(protein_emb, f, pickle_module=dill)

Saving protein embeddings......0
Saving protein embeddings......1
Saving protein embeddings......2
Saving protein embeddings......3
Saving protein embeddings......4
Saving protein embeddings......5
Saving protein embeddings......6
Saving protein embeddings......7
Saving protein embeddings......8
Saving protein embeddings......9
Saving protein embeddings......10
Saving protein embeddings......11
Saving protein embeddings......12
Saving protein embeddings......13
Saving protein embeddings......14
Saving protein embeddings......15
Saving protein embeddings......16
Saving protein embeddings......17
Saving protein embeddings......18
Saving protein embeddings......19
Saving protein embeddings......20
Saving protein embeddings......21
Saving protein embeddings......22
Saving protein embeddings......23
Saving protein embeddings......24
Saving protein embeddings......25
Saving protein embeddings......26
Saving protein embeddings......27
Saving protein embeddings......28
Saving protein embedding

In [None]:
# NOTE(review): `chunks_test` is not defined in any visible cell; this
# cell was never executed — TODO confirm whether it is still needed.
chunks_test

In [None]:
# Loop that generates all embeddings for all proteins
# NOTE(review): this cell was never executed. `df_list` and
# `get_protein_repr` are not defined in any visible cell, so it fails on a
# fresh kernel — presumably an earlier version of the chunked loop; TODO confirm.
for count, df in enumerate(df_list):
    
    # Get emb for every batch
    # row[1] is the first column of the frame (row[0] is the index).
    for row in df.itertuples():
        df.at[row.Index, 'embedding'] = torch.Tensor(get_protein_repr(row[1], prot_trans_embedder))
        
    df = df.set_index('From')
    
    # save every batch as a .pt
    df = df[['embedding']]

    # Make it a dict
    protein_emb = df.to_dict()

    print(f"Saving protein embeddings......")
    # Save it
    filename = f'../data/processed/{count}_protein_batch.pt'

    with open(filename, 'wb') as f:  # Overwrites any existing file.
        torch.save(protein_emb, f, pickle_module=dill)

In [None]:
# Same embed-and-save procedure applied to a single frame and written as
# batch number 7.
# NOTE(review): this cell was never executed. `protein_sequences_h` and
# `get_protein_repr` are not defined in any visible cell — TODO confirm
# whether this cell is still needed.
for row in protein_sequences_h.itertuples():
    protein_sequences_h.at[row.Index, 'embedding'] = torch.Tensor(get_protein_repr(row[1], prot_trans_embedder))
        
df = protein_sequences_h.set_index('From')

# save every batch as a .pt
df = df[['embedding']]

# Make it a dict
protein_emb = df.to_dict()

print(f"Saving protein embeddings......")
# Save it
filename = f'../data/processed/7_protein_batch.pt'

with open(filename, 'wb') as f:  # Overwrites any existing file.
    torch.save(protein_emb, f, pickle_module=dill)

In [None]:
# NOTE(review): `protein_sequences_a` is not defined in any visible cell — TODO confirm.
protein_sequences_a.head(3)

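In the above example we see that a protein of 406 amino acids is represented by a (406, 1024) matrix: one 1024-dimensional vector per amino acid.

To get the final per-protein representation we "squash" the amino acids together by averaging over the amino-acid axis, which turns the (406, 1024) matrix into a single 1024-dimensional vector.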

In [None]:
### Load embeddings - merge - store them for use down the line

In [None]:
# Index the sequence table by the 'From' column.
# NOTE(review): `protein_sequences` is not defined in any visible cell — TODO confirm.
test_prot = protein_sequences.set_index('From')

In [None]:
# Display the table ordered by its index; the sorted result is not stored.
test_prot.sort_index()

In [None]:
# Load one saved batch back from disk and view it as a DataFrame.
# Despite its name, `test_tensor` holds whatever object was saved (the
# saving loop in this notebook stores a dict), not a single tensor.
# NOTE(review): the file was written with pickle_module=dill but is loaded
# with torch's default unpickler; presumably fine for a dict of tensors — TODO confirm.
test_tensor = torch.load(f'../data/biokg/uniprot/embeddings/1_protein_batch.pt')
test_df = pd.DataFrame.from_dict(test_tensor)

In [None]:
# Display the DataFrame built from the loaded batch.
test_df