### Protein embeddings were generated using ANKH LLM using protein sequences

In [1]:
from pathlib import Path
import pandas as pd


#
base_path = Path('path_to_data_directory')


bp_train_path = base_path / 'bp/train_data.pkl'
bp_valid_path = base_path / 'bp/valid_data.pkl'
bp_test_path = base_path / 'bp/test_data.pkl'

cc_train_path = base_path / 'cc/train_data.pkl'
cc_valid_path = base_path / 'cc/valid_data.pkl'
cc_test_path = base_path / 'cc/test_data.pkl'

mf_train_path = base_path / 'mf/train_data.pkl'
mf_valid_path = base_path / 'mf/valid_data.pkl'
mf_test_path = base_path / 'mf/test_data.pkl'


def preprocess(data_path, data_type, ont):
    data = pd.read_pickle(data_path)
    data.rename(columns={'prop_annotations': 'term'}, inplace=True)
    data = data[['proteins', 'sequences', 'term']].rename(columns={'proteins': 'protein_name'})
    data['Set'] = data_type
    data['aspect'] = ont
    return data


bp_train = preprocess(bp_train_path, "Train", "BPO")
cc_train = preprocess(cc_train_path, "Train", "CCO")
mf_train = preprocess(mf_train_path, "Train", "MFO")

bp_valid = preprocess(bp_valid_path, "Valid", "BPO")
cc_valid = preprocess(cc_valid_path, "Valid", "CCO")
mf_valid = preprocess(mf_valid_path, "Valid", "MFO")

bp_test = preprocess(bp_test_path, "Test", "BPO")
cc_test = preprocess(cc_test_path, "Test", "CCO")
mf_test = preprocess(mf_test_path, "Test", "MFO")

# Concatenate 
mf = pd.concat([mf_train, mf_valid, mf_test], ignore_index=True)
cc = pd.concat([cc_train, cc_valid, cc_test], ignore_index=True)
bp = pd.concat([bp_train, bp_valid, bp_test], ignore_index=True)

data = pd.concat([bp, cc, mf], ignore_index=True)


In [6]:
data.head()

Unnamed: 0,protein_name,sequences,term,Set,aspect
0,VGFR2_MOUSE,MESKALLAVALWFCVETRAASVGLPGDFLHPPKLSTQKDILTILAN...,"[GO:0004713, GO:0016310, GO:0002040, GO:004259...",Train,BPO
1,VGFR2_RAT,MESRALLAVALWFCVETRAASVGLPGDSLHPPKLSTQKDILTILAN...,"[GO:0060548, GO:0050804, GO:0004713, GO:001631...",Train,BPO
2,VGFR2_HUMAN,MQSKVLLAVALWLCVETRAASVGLPSVSLDLPRLSIQKDILTIKAN...,"[GO:0043536, GO:0002040, GO:0008283, GO:004259...",Train,BPO
3,VGFR2_DANRE,MAKTSYALLLLDILLTFNVAKAIELRFVPDPPTLNITEKTIKINAS...,"[GO:0001525, GO:0032502, GO:0007275, GO:000815...",Train,BPO
4,KIT_MOUSE,MRGARGAWDLLCVLLVLLRGQTATSQPSASPGEPSPPSIHPAQSEL...,"[GO:0006664, GO:0002371, GO:0036216, GO:004259...",Train,BPO


In [7]:
df_seq = data[['protein_name', 'sequences']].drop_duplicates().reset_index(drop = True)
df_seq.head()

Unnamed: 0,protein_name,sequences
0,VGFR2_MOUSE,MESKALLAVALWFCVETRAASVGLPGDFLHPPKLSTQKDILTILAN...
1,VGFR2_RAT,MESRALLAVALWFCVETRAASVGLPGDSLHPPKLSTQKDILTILAN...
2,VGFR2_HUMAN,MQSKVLLAVALWLCVETRAASVGLPSVSLDLPRLSIQKDILTIKAN...
3,VGFR2_DANRE,MAKTSYALLLLDILLTFNVAKAIELRFVPDPPTLNITEKTIKINAS...
4,KIT_MOUSE,MRGARGAWDLLCVLLVLLRGQTATSQPSASPGEPSPPSIHPAQSEL...


### ANKH LLM

In [8]:
import torch
import ankh
import tqdm
import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model, tokenizer = ankh.load_large_model()
model.to(device)
model.eval()


def get_ankh_embeddings(sequence):
    tokens = tokenizer.encode_plus(sequence, add_special_tokens=True, return_tensors='pt').to(device)

    with torch.no_grad():
        embeddings = model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask'])
    
    return embeddings.last_hidden_state.mean(dim=1).detach().cpu().numpy()  # Average pooling over tokens


def process_sequences(df, embedding_file_prefix):
    ids = []
    embeds = []
    for index, row in tqdm.tqdm(df_seq.head(10).iterrows(), total=len(df_seq), desc="Processing sequences"):
        seq = row['sequences']
        protein_id = row['protein_name']
        
        seq = seq[:3000]
        embeddings = get_ankh_embeddings(seq)
        ids.append(protein_id)
        embeds.append(embeddings)
    

    np.save(f'{embedding_file_prefix}_zerogo_ids.npy', np.array(ids))
    np.save(f'{embedding_file_prefix}_zerogo_embeds.npy', np.array(embeds))


process_sequences(df_seq, embedding_file_prefix='ankh_embeddings')

Processing sequences:   0%|                | 10/77638 [00:02<4:50:48,  4.45it/s]
