Original code adapted from the ProtGPS GitHub repository; source notebook: https://github.com/pgmikhael/protgps/blob/main/notebook/Predict.ipynb

In [1]:
import os

# Path to the local checkout of the ProtGPS repository: the directory that contains the
# `protgps` package (it is appended to sys.path before `protgps` is imported, and the
# checkpoint paths below are built relative to it).
# Set the PROTGPS_PARENT_DIR environment variable to point elsewhere without editing the notebook.
PROTGPS_PARENT_DIR = os.environ.get(
    "PROTGPS_PARENT_DIR",
    "/Users/marcellamirabelli/Desktop/rbp_localization_project/models/ProtGPS/protgps-main",
)

In [2]:
import sys
import os
sys.path.append(PROTGPS_PARENT_DIR) # append the path of protgps
from argparse import Namespace
import pickle
from tqdm import tqdm
import pandas as pd
import torch 
from protgps.utils.loading import get_object

In [3]:
# Use the first CUDA GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Output labels for the ProtGPS classifier. Column j of the score tensor returned by
# predict_condensates is reported under COMPARTMENT_CLASSES[j], so this order must match
# the order of the model's output units (assumed to be the training label order — TODO confirm).
# NOTE(review): "pml-bdoy" looks like a misspelling of "pml-body"; it is kept verbatim because
# renaming it would change the "PML-BDOY_Score" column name that downstream code may rely on.
COMPARTMENT_CLASSES = [
    "nuclear_speckle",
    "p-body",
    "pml-bdoy",
    "post_synaptic_density",
    "stress_granule",
    "chromosome",
    "nucleolus",
    "nuclear_pore_complex",
    "cajal_body",
    "rna_granule",
    "cell_junction",
    "transcriptional",
]

def load_model(snargs):
    """
    Load the ProtGPS Lightning classifier described by an args namespace.

    Parameters
    ----------
    snargs : argparse.Namespace
        Must provide ``lightning_name`` (registry key resolved with
        ``get_object(..., "lightning")``), ``model_path`` (checkpoint file to load) and
        ``relax_checkpoint_matching`` (when truthy, the checkpoint is loaded with
        ``strict=False``). The namespace itself is forwarded to the module as ``args``.

    Returns
    -------
    The module instance returned by the class's ``load_from_checkpoint``.
    """
    model_class = get_object(snargs.lightning_name, "lightning")
    # load_from_checkpoint is a classmethod, so call it on the class directly. Instantiating
    # model_class(snargs) first would construct the whole model an extra time only to throw
    # that instance away.
    return model_class.load_from_checkpoint(
        checkpoint_path=snargs.model_path,
        strict=not snargs.relax_checkpoint_matching,
        args=snargs,
    )

@torch.no_grad()
def predict_condensates(model, sequences, batch_size=1, round=True):
    """
    Run the classifier over ``sequences`` in fixed-size chunks and return sigmoid scores.

    Gradient tracking is disabled for the whole call by the ``torch.no_grad`` decorator,
    and a tqdm progress bar (100 columns wide) is shown over the chunks.

    Parameters
    ----------
    model :
        Wrapper whose ``.model`` attribute is called with ``{"x": chunk}`` and returns a
        mapping with a ``"logit"`` tensor (assumed shape (len(chunk), n_classes) — TODO
        confirm against the ProtGPS model).
    sequences :
        Sliceable sequence of protein strings (a list in this notebook).
    batch_size : int
        Number of sequences handed to the model per forward call.
    round : bool
        When True, scores are rounded to 3 decimal places. The name shadows the built-in
        ``round`` inside this function; it is kept for API compatibility.

    Returns
    -------
    torch.Tensor
        Per-chunk sigmoid outputs moved to the CPU and stacked row-wise with
        ``torch.vstack``. An empty ``sequences`` makes ``torch.vstack`` raise, since there
        is nothing to stack.
    """
    chunk_starts = range(0, len(sequences), batch_size)
    per_chunk = []
    for start in tqdm(chunk_starts, ncols=100):
        chunk = sequences[start : start + batch_size]
        logits = model.model({"x": chunk})["logit"]
        per_chunk.append(torch.sigmoid(logits).to("cpu"))
    probabilities = torch.vstack(per_chunk)
    if not round:
        return probabilities
    return torch.round(probabilities, decimals=3)

In [5]:
# Restore the training-time arguments saved next to the checkpoint, then point them at the
# local checkpoint file and ESM cache before building the model.
# NOTE(review): pickle.load can execute arbitrary code — only load .args files you trust.
CHECKPOINT_DIR = os.path.join(PROTGPS_PARENT_DIR, "protgps/checkpoints/protgps")
CHECKPOINT_ID = "32bf44b16a4e770a674896b81dfb3729"
with open(os.path.join(CHECKPOINT_DIR, f"{CHECKPOINT_ID}.args"), "rb") as args_file:
    args = Namespace(**pickle.load(args_file))
args.model_path = os.path.join(CHECKPOINT_DIR, f"{CHECKPOINT_ID}epoch=26.ckpt")
# Folder that holds the cached ESM2 `facebookresearch_esm_main` torch-hub directory.
# Override with the PROTGPS_ESM_HUB_DIR environment variable.
# NOTE(review): the default uses lowercase "/users"; this only resolves on a
# case-insensitive filesystem (the macOS default) — consider "/Users".
args.pretrained_hub_dir = os.environ.get(
    "PROTGPS_ESM_HUB_DIR", "/users/marcellamirabelli/protgps/esm_models/esm2"
)
model = load_model(args)
model.eval()
model = model.to(device)


Using cache found in /users/marcellamirabelli/protgps/esm_models/esm2/facebookresearch_esm_main


Using ESM hidden layers 6
Using ESM hidden layers 6


Using cache found in /users/marcellamirabelli/protgps/esm_models/esm2/facebookresearch_esm_main


In [6]:
# Protein sequences to score.
sequences = [
    # UniProt O15116
    "MNYMPGTASLIEDIDKKHLVLLRDGRTLIGFLRSIDQFANLVLHQTVERIHVGKKYGDIPRGIFVVRGENVVLLGEIDLEKESDTPLQQVSIEEILEEQRVEQQTKLEAEKLKVQALKDRGLSIPRADTLDEY",
    # UniProt P38432
    "MAASETVRLRLQFDYPPPATPHCTAFWLLVDLNRCRVVTDLISLIRQRFGFSSGAFLGLYLEGGLLPPAESARLVRDNDCLRVKLEERGVAENSVVISNGDINLSLRKAKKRAFQLEEGEETEPDCKYSKKHWKSRENNNNNEKVLDLEPKAVTDQTVSKKNKRKNKATCGTVGDDNEEAKRKSPKKKEKCEYKKKAKNPKSPKVQAVKDWANQRCSSPKGSARNSLVKAKRKGSVSVCSKESPSSSSESESCDESISDGPSKVTLEARNSSEKLPTELSKEEPSTKNTTADKLAIKLGFSLTPSKGKTSGTTSSSSDSSAESDDQCLMSSSTPECAAGFLKTVGLFAGRGRPGPGLSSQTAGAAGWRRSGSNGGGQAPGASPSVSLPASLGRGWGREENLFSWKGAKGRGMRGRGRGRGHPVSCVVNRSTDNQRQQQLNDVVKNSSTIIQNPVETPKKDYSLLPLLAAAPQVGEKIAFKLLELTSSYSPDVSDYKEGRILSHNPETQQVDIEILSSLPALREPGKFDLVYHNENGAEVVEYAVTQESKITVFWKELIDPRLIIESPSNTSSTEPA",
    # Test tile from the final 115 library
    "KKRRKEQEEKAEIKRLKNSDDRDSKRDSLEEGELRDHRMEITIRNSPYRREDSMEDRGEEDDSLAIKPPQQMSRKEKVHHRKDEKRKEKRRHRSHSAEGGKHARVKEKEREHERR",
]

In [7]:
# Score each sequence on its own forward pass, rounding the sigmoid scores to 3 decimals.
scores = predict_condensates(model=model, sequences=sequences, batch_size=1, round=True)

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  7.33it/s]


In [8]:
# Build one column per compartment: scores[:, j] is labeled COMPARTMENT_CLASSES[j], so the
# label list and the model's score columns must line up exactly. Without these checks,
# extra score columns would be silently dropped.
assert scores.shape[1] == len(COMPARTMENT_CLASSES), (
    f"model returned {scores.shape[1]} score columns but "
    f"{len(COMPARTMENT_CLASSES)} compartment labels are defined"
)
assert scores.shape[0] == len(sequences), "expected one score row per input sequence"
data = {"sequences": sequences}
for j, condensate in enumerate(COMPARTMENT_CLASSES):
    data[f"{condensate.upper()}_Score"] = scores[:, j].tolist()

In [9]:
# Display the results: one row per input sequence, one *_Score column per compartment.
pd.DataFrame.from_dict(data, orient="columns")

Unnamed: 0,sequences,NUCLEAR_SPECKLE_Score,P-BODY_Score,PML-BDOY_Score,POST_SYNAPTIC_DENSITY_Score,STRESS_GRANULE_Score,CHROMOSOME_Score,NUCLEOLUS_Score,NUCLEAR_PORE_COMPLEX_Score,CAJAL_BODY_Score,RNA_GRANULE_Score,CELL_JUNCTION_Score,TRANSCRIPTIONAL_Score
0,MNYMPGTASLIEDIDKKHLVLLRDGRTLIGFLRSIDQFANLVLHQT...,0.0,0.999,0.0,0.0,0.0,0.0,0.001,0.001,0.0,0.0,0.0,0.0
1,MAASETVRLRLQFDYPPPATPHCTAFWLLVDLNRCRVVTDLISLIR...,0.0,0.0,0.003,0.0,0.0,0.0,0.004,0.002,0.992,0.0,0.0,0.0
2,KKRRKEQEEKAEIKRLKNSDDRDSKRDSLEEGELRDHRMEITIRNS...,0.199,0.005,0.0,0.0,0.0,0.0,0.988,0.0,0.0,0.0,0.0,0.0
