In [1]:
import os  # imported here because this cell runs before the main import cell

# Root of the local protgps checkout. Set the PROTGPS_PARENT_DIR environment
# variable to use a different location without editing the notebook; the
# default below is the original hard-coded path.
PROTGPS_PARENT_DIR = os.environ.get("PROTGPS_PARENT_DIR", "/home/mt1022/codehub/protgps") # point to the protgps local repo

In [2]:
import sys
import os
sys.path.append(PROTGPS_PARENT_DIR) # append the path of protgps
from argparse import Namespace
import pickle
from tqdm import tqdm
import pandas as pd
import torch 
from protgps.utils.loading import get_object

In [15]:
from Bio import SeqIO

def read_fasta(file_path):
    """
    Read a FASTA file and return its record ids and sequences.

    Args:
        file_path: location of the FASTA file, passed straight to
            ``SeqIO.parse(file_path, "fasta")``.

    Returns:
        A ``(ids, sequences)`` tuple of two lists in file order; ``ids[i]`` is
        ``record.id`` and ``sequences[i]`` is ``str(record.seq)`` of record i.
    """
    records = list(SeqIO.parse(file_path, "fasta"))
    ids = [rec.id for rec in records]
    sequences = [str(rec.seq) for rec in records]
    return ids, sequences

In [3]:
# Run on the first CUDA GPU when torch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Display the device selected in the previous cell.
device

device(type='cuda', index=0)

In [5]:
# Compartment labels used to name the score columns.
# Order matters: later cells label column j of the score tensor with entry j
# of this list, so entries must not be reordered.
COMPARTMENT_CLASSES = [
    "nuclear_speckle",
    "p-body",
    "pml-body",  # was misspelled "pml-bdoy", which leaked into the output column name
    "post_synaptic_density",
    "stress_granule",
    "chromosome",
    "nucleolus",
    "nuclear_pore_complex",
    "cajal_body",
    "rna_granule",
    "cell_junction",
    "transcriptional"
]

def load_model(snargs):
    """
    Loads classifier model from args file

    Args:
        snargs: namespace-like object providing ``lightning_name`` (looked up
            via ``get_object(..., "lightning")``), ``model_path`` (checkpoint
            file) and ``relax_checkpoint_matching`` (when truthy, the
            checkpoint is loaded with ``strict=False``).

    Returns:
        The object returned by ``load_from_checkpoint``.
    """
    lightning_factory = get_object(snargs.lightning_name, "lightning")
    # ``load_from_checkpoint`` is a classmethod: invoke it on the class rather
    # than on a throwaway instance. This avoids constructing the model twice
    # and keeps working on Lightning releases that reject instance calls.
    if isinstance(lightning_factory, type):
        model_cls = lightning_factory
    else:
        # Fallback if the registry hands back a factory callable instead of a
        # class: build once to discover the concrete class.
        model_cls = type(lightning_factory(snargs))
    return model_cls.load_from_checkpoint(
        checkpoint_path=snargs.model_path,
        strict=not snargs.relax_checkpoint_matching,
        args=snargs,
    )

@torch.no_grad()
def predict_condensates(model, sequences, batch_size=1, round=True):
    """
    Score sequences with ``model`` and return sigmoid-transformed logits.

    Args:
        model: object whose ``model`` attribute is callable with
            ``{"x": <list of sequences>}`` and returns a mapping with a
            ``'logit'`` tensor.
        sequences: sliceable collection of sequences; consumed in consecutive
            slices of ``batch_size`` items.
        batch_size: number of sequences passed to the model per call.
        round: when truthy, round the scores to 3 decimals.

    Returns:
        A CPU tensor built by vertically stacking the per-batch scores
        (presumably one row per sequence and one column per class; confirm
        against the model's output shape).
    """
    def _score_batch(start):
        # Slice out one batch, run the wrapped network, squash logits to (0, 1).
        chunk = sequences[start : start + batch_size]
        logits = model.model({"x": chunk})["logit"]
        return torch.sigmoid(logits).to("cpu")

    batch_starts = range(0, len(sequences), batch_size)
    per_batch = [_score_batch(start) for start in tqdm(batch_starts, ncols=100)]
    stacked = torch.vstack(per_batch)
    return torch.round(stacked, decimals=3) if round else stacked

In [6]:
# Load the pickled training arguments that ship with the checkpoint.
# NOTE: unpickling can execute arbitrary code -- only load .args files obtained
# from a trusted source.
args_path = os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729.args') # assumes args file has been extracted in checkpoints/protgps
with open(args_path, 'rb') as args_file:  # context manager closes the handle (the original left it open)
    args = Namespace(**pickle.load(args_file))
args.model_path = os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729epoch=26.ckpt') # assumes checkpoint has been extracted in checkpoints/protgps
args.pretrained_hub_dir =  os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/esm2') # should point to folder with ESM2 facebookresearch_esm_main directory
# Restore the model from the checkpoint, switch it to eval mode, and move it to the selected device.
model = load_model(args)
model.eval()
model = model.to(device)

Using cache found in /home/mt1022/codehub/protgps/checkpoints/esm2/facebookresearch_esm_main
Using cache found in /home/mt1022/codehub/protgps/checkpoints/esm2/facebookresearch_esm_main


Using ESM hidden layers 6
Using ESM hidden layers 6


In [16]:
# tests
# Two example proteins, keyed by UniProt accession, used to smoke-test the model.
test_proteins = {
    # UniProt O15116
    'O15116': "MNYMPGTASLIEDIDKKHLVLLRDGRTLIGFLRSIDQFANLVLHQTVERIHVGKKYGDIPRGIFVVRGENVVLLGEIDLEKESDTPLQQVSIEEILEEQRVEQQTKLEAEKLKVQALKDRGLSIPRADTLDEY",
    # UniProt P38432
    'P38432': "MAASETVRLRLQFDYPPPATPHCTAFWLLVDLNRCRVVTDLISLIRQRFGFSSGAFLGLYLEGGLLPPAESARLVRDNDCLRVKLEERGVAENSVVISNGDINLSLRKAKKRAFQLEEGEETEPDCKYSKKHWKSRENNNNNEKVLDLEPKAVTDQTVSKKNKRKNKATCGTVGDDNEEAKRKSPKKKEKCEYKKKAKNPKSPKVQAVKDWANQRCSSPKGSARNSLVKAKRKGSVSVCSKESPSSSSESESCDESISDGPSKVTLEARNSSEKLPTELSKEEPSTKNTTADKLAIKLGFSLTPSKGKTSGTTSSSSDSSAESDDQCLMSSSTPECAAGFLKTVGLFAGRGRPGPGLSSQTAGAAGWRRSGSNGGGQAPGASPSVSLPASLGRGWGREENLFSWKGAKGRGMRGRGRGRGHPVSCVVNRSTDNQRQQQLNDVVKNSSTIIQNPVETPKKDYSLLPLLAAAPQVGEKIAFKLLELTSSYSPDVSDYKEGRILSHNPETQQVDIEILSSLPALREPGKFDLVYHNENGAEVVEYAVTQESKITVFWKELIDPRLIIESPSNTSSTEPA",
}
# Split into parallel lists (dicts preserve insertion order, so ids[i] pairs with sequences[i]).
ids = list(test_proteins.keys())
sequences = list(test_proteins.values())

In [17]:
# Score the two example proteins one sequence at a time.
scores = predict_condensates(
    model,
    sequences=sequences,
    batch_size=1,
)

100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.61it/s]


In [18]:
# Column-oriented table: ids and sequences first, then one "<CLASS>_Score"
# column per compartment, taken from the matching column of `scores`.
data = dict(id=ids, seq=sequences)
data.update(
    (f"{name.upper()}_Score", scores[:, col].tolist())
    for col, name in enumerate(COMPARTMENT_CLASSES)
)

In [19]:
# Render the example predictions as a table.
pd.DataFrame(data=data)

Unnamed: 0,id,seq,NUCLEAR_SPECKLE_Score,P-BODY_Score,PML-BDOY_Score,POST_SYNAPTIC_DENSITY_Score,STRESS_GRANULE_Score,CHROMOSOME_Score,NUCLEOLUS_Score,NUCLEAR_PORE_COMPLEX_Score,CAJAL_BODY_Score,RNA_GRANULE_Score,CELL_JUNCTION_Score,TRANSCRIPTIONAL_Score
0,O15116,MNYMPGTASLIEDIDKKHLVLLRDGRTLIGFLRSIDQFANLVLHQT...,0.0,0.999,0.0,0.0,0.0,0.0,0.001,0.001,0.0,0.0,0.0,0.0
1,P38432,MAASETVRLRLQFDYPPPATPHCTAFWLLVDLNRCRVVTDLISLIR...,0.0,0.0,0.003,0.0,0.0,0.0,0.004,0.002,0.992,0.0,0.0,0.0


### ncORF prediction

In [31]:
# Load the human ncORF peptide ids and sequences.
ids, seqs = read_fasta(
    file_path='/home/mt1022/poj/ncorf_mammals/data/ncorf_v2412/human_three_methods_ncorf.final.pep.fa'
)

In [32]:
# Score the human ncORF peptides, two sequences per model call.
# NOTE(review): the mouse run below uses batch_size=1 -- confirm that batching
# sequences of different lengths together does not change the scores.
scores = predict_condensates(
    model,
    sequences=seqs,
    batch_size=2,
)

100%|██████████████████████████████████████████████████████████| 5812/5812 [00:47<00:00, 122.23it/s]


In [33]:
# Assemble the human ncORF results: ids and sequences first, then one
# "<CLASS>_Score" column per compartment from the matching column of `scores`.
data = dict(id=ids, seq=seqs)
data.update(
    (f"{name.upper()}_Score", scores[:, col].tolist())
    for col, name in enumerate(COMPARTMENT_CLASSES)
)

ncorf_protgps = pd.DataFrame(data=data)
ncorf_protgps

Unnamed: 0,id,seq,NUCLEAR_SPECKLE_Score,P-BODY_Score,PML-BDOY_Score,POST_SYNAPTIC_DENSITY_Score,STRESS_GRANULE_Score,CHROMOSOME_Score,NUCLEOLUS_Score,NUCLEAR_PORE_COMPLEX_Score,CAJAL_BODY_Score,RNA_GRANULE_Score,CELL_JUNCTION_Score,TRANSCRIPTIONAL_Score
0,14_73147866_73148072_+,MKERTSRGFVFCETVFLYSCSNDRVTCTVVLLPECTDV,0.0,0.003,0.0,0.000,0.0,1.000,0.000,0.000,0.000,0.0,0.000,0.0
1,14_73147810_73147869_+,MQVTTAFAVLRQLGLEENT,0.0,0.000,0.0,0.000,0.0,0.946,0.000,0.996,0.000,0.0,0.000,0.0
2,14_73147802_73147819_+,MGACK,0.0,0.819,0.0,0.000,0.0,0.236,0.036,0.000,0.000,0.0,0.000,0.0
3,22_40346522_40346650_+,MVQSTLAGSQGWDGGWRRSWFARQLPLTSCLPLCQPGDVLRV,0.0,0.001,0.0,0.000,0.0,0.928,0.000,0.960,0.000,0.0,0.000,0.0
4,9_113393618_113393586_-,MPQEPSVPTN,0.0,0.250,0.0,0.000,0.0,0.269,0.303,0.000,0.000,0.0,0.000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11618,5_84384852_84387847_+,MCTNRGGVRPLEGIGP,0.0,0.877,0.0,0.000,0.0,0.012,0.000,0.002,0.000,0.0,0.000,0.0
11619,4_6986980_6986942_-,MRNVLPKEIAQK,0.0,0.000,0.0,0.000,0.0,0.970,0.001,0.000,0.000,0.0,0.000,0.0
11620,7_105014024_105013986_-,MEAVRGTESLSF,0.0,0.000,0.0,0.000,0.0,1.000,0.080,0.584,0.000,0.0,0.000,0.0
11621,2_232550702_232550728_+,VATAAARG,0.0,0.000,0.0,0.547,0.0,0.000,0.351,0.000,0.001,0.0,0.000,0.0


In [35]:
# Write the human ncORF scores as a gzip-compressed CSV, omitting the row index.
ncorf_protgps.to_csv(
    path_or_buf='/home/mt1022/poj/ncorf_mammals/data/ncorf_v2412/ProtGPS_ncORF_human.csv.gz',
    index=False,
    compression='gzip',
)

In [36]:
# Load the mouse ncORF peptide ids and sequences (replaces the human `ids`/`seqs`).
ids, seqs = read_fasta(
    file_path='/home/mt1022/poj/ncorf_mammals/data/ncorf_v2412/mouse_three_methods_ncorf.final.pep.fa'
)

In [39]:
# Score the mouse ncORF peptides one sequence per model call.
scores = predict_condensates(
    model,
    sequences=seqs,
    batch_size=1,
)

100%|████████████████████████████████████████████████████████| 16485/16485 [01:16<00:00, 214.58it/s]


In [40]:
# Assemble the mouse ncORF results: ids and sequences first, then one
# "<CLASS>_Score" column per compartment from the matching column of `scores`.
# This rebinds `ncorf_protgps`, which previously held the human table.
data = dict(id=ids, seq=seqs)
data.update(
    (f"{name.upper()}_Score", scores[:, col].tolist())
    for col, name in enumerate(COMPARTMENT_CLASSES)
)

ncorf_protgps = pd.DataFrame(data=data)
ncorf_protgps

Unnamed: 0,id,seq,NUCLEAR_SPECKLE_Score,P-BODY_Score,PML-BDOY_Score,POST_SYNAPTIC_DENSITY_Score,STRESS_GRANULE_Score,CHROMOSOME_Score,NUCLEOLUS_Score,NUCLEAR_PORE_COMPLEX_Score,CAJAL_BODY_Score,RNA_GRANULE_Score,CELL_JUNCTION_Score,TRANSCRIPTIONAL_Score
0,3_91989103_91987748_-,MEVVPAAAVEAAPVGASSTPEAAVALAAAAATPEAVVALAAAVATL...,0.000,0.000,0.0,0.876,0.000,0.000,0.000,0.000,0.0,0.0,0.000,0.0
1,7_41130921_41129584_-,MKEVILEKNLLYILNVIKHLHIRVILKDIIAFILERNPMKVFNMIK...,0.000,0.000,0.0,0.000,0.000,0.458,0.000,0.007,0.0,0.0,0.015,0.0
2,17_56431599_56430760_-,MAGQRQRCCMVALTTWTRGCCPPQSRRLYPRNACPRHTHTPTPWGL...,0.000,0.000,0.0,0.000,0.000,0.024,0.290,0.000,0.0,0.0,0.002,0.0
3,7_45482681_45481926_-,MALPPLRQPGAATATTPSQLLARRPRRWLVGASASTLGCGATAPAP...,0.005,0.000,0.0,0.000,0.000,0.002,0.001,0.000,0.0,0.0,0.001,0.0
4,19_5738220_5736124_-,MRVPGGGRAQQRLGHGVSGVADPAPGGVGLAHWGMVLRSASEHLLQ...,0.000,0.001,0.0,0.000,0.003,0.904,0.001,0.000,0.0,0.0,0.000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
16480,15_12321117_12321091_-,MFMYCALA,0.000,0.000,0.0,0.000,0.000,0.000,0.000,0.002,0.0,0.0,1.000,0.0
16481,2_68410854_68408984_-,MKQLGNQEKFTVRSLTCTPLSQFLFIVVNK,0.000,0.087,0.0,0.000,0.000,0.999,0.550,0.000,0.0,0.0,0.000,0.0
16482,7_141119925_141120011_+,MDEDRAECAERVLRSRRCCFRSCSQSLY,0.000,0.000,0.0,0.000,0.000,1.000,0.001,0.000,0.0,0.0,0.000,0.0
16483,16_16532785_16528785_-,MKTIDLYYTLLPVGDCGGPIILLLR,0.000,0.980,0.0,0.000,0.000,0.000,0.013,0.000,0.0,0.0,0.000,0.0


In [41]:
# Write the mouse ncORF scores as a gzip-compressed CSV, omitting the row index.
ncorf_protgps.to_csv(
    path_or_buf='/home/mt1022/poj/ncorf_mammals/data/ncorf_v2412/ProtGPS_ncORF_mouse.csv.gz',
    index=False,
    compression='gzip',
)