In [1]:
# Reload edited modules before each cell runs, so changes to the pLMtrainer
# sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
import yaml
import json
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer

# Make the project package importable from this notebook. Both entries are
# relative paths, resolved against the kernel's current working directory.
sys.path.append('..')
sys.path.append('pLMtrainer')
from pLMtrainer.models.frustraSeq import FrustraSeq
from pLMtrainer.dataloader import FrustrationDataModule

  from .autonotebook import tqdm as notebook_tqdm


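# FrustraSeq inference example

This notebook loads protein sequences from a FASTA file and runs a trained FrustraSeq checkpoint on them. It then collects the per-residue predictions (regression, classification, entropy, surprisal) into DataFrames that can be filtered and saved to CSV.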

## 1) Load Data
First, we load a FASTA file of protein sequences and store it in a dictionary of the form `{id: sequence}`, where each id is the header line without its leading `>`.

In [2]:
# Input FASTA with the sequences to annotate (relative to the kernel's working directory).
fasta_file_path = (
    "../data/frustration/"
    "bonomi_ensembles_sequences.fasta"
)

In [3]:
def read_fasta(path):
    """Parse a FASTA file into an insertion-ordered ``{header: sequence}`` dict.

    The header is the text after ``>`` with surrounding whitespace stripped.
    Sequences spanning several lines are concatenated, and blank lines are
    ignored. A repeated header restarts that record's sequence, as before.

    Args:
        path: Path to the FASTA file.

    Returns:
        dict mapping each header to its full sequence string.

    Raises:
        ValueError: if sequence data appears before the first ``>`` header.
    """
    records = {}
    header = None
    with open(path, "r") as handle:
        # Stream line by line instead of reading the whole file into memory.
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines between records / at EOF
            if line.startswith(">"):
                header = line[1:].strip()
                records[header] = ""
            elif header is None:
                raise ValueError(f"{path}: sequence data found before the first '>' header")
            else:
                records[header] += line
    return records


seqs = read_fasta(fasta_file_path)
seqs

{'Alb3-A3CT': 'MDENASKIISAGRAKRSIAQPDDAGERFRQLKEQEKRSKKNKAVAKDTVELVEESQSESEEGSDDEEEEAREGALASSTTSKPLPEVGQRRSKRSKRKRTV',
 'FCP1': 'PGPEEQEEEPQPRKPGTRRERTLGAPASSERSAAGGRGPRGHKRKLNEEDAASESSRESSNEDEGSSSEADEMAKALEAELNDLM',
 'emerin_67-170': 'GTRGDADMYDLPKKEDALLYQSKGYNDDYYEESYFTTRTYGEPESAGPSRAVRQSVTSFPDADAFHHQVHDDDLLSSSEEECKDRERPMYGRDSAYQSITHYRPV',
 'UBact': 'MIQSLMPERRERPGDPMPKSPSPLEEGGGPRRPETGSPDKDSLLKRMRRVDPKQAERYRQRTGE',
 'Nsp2_CtlIDR': 'KEIIFLEGETLPTEVLTEEVVLKTGDLQPLEQPTSEAVEAPLVGT',
 'NHE1': 'MINNYLTVPAHKLDSPTMSRARIGSDPLAYEPKEDLPVITIDPASPQSPESVDLVNEELKGKVLGLSRDPAKVAEEDEDDDGGIMMRSKETSSPGTDDVFTPAPSDSPSSQRIQRCLSDP',
 'p61_Hck': 'GGRSSCEDPGCPRDEERAPRMGCMKSKFLQVGGNTFSKTETSASPHCPVYVPDPTSTIKPGPNSHNSNTPGIREAGSE',
 'ACTR': 'GTQNRPLLRNSLDDLVGPPSNLEGQSDERALLDQLHTLLSNTDATGLEEIDRALGIPELVNQGQALEPKQD',
 'Hug1': 'AMADPMTMDQGLNPKQFFLDDVVLQDTLCSMSNRVNKSVKTGYLFPKDHVPSANIIAVERRGGLSDIGKNTSN',
 'PaaA2': 'MDYKDDDDKNRALSPMVSEFETIEQENSYNEWLRAKVATSLADPRPAIPHDEVERRMAERFAKMRKERSKQ',
 'Nt-SOCS5': 'RSLRQRLQDTVGLCFPM

In [4]:
# One row per FASTA record: the header becomes `id`, the residues `sequence`.
df = (
    pd.DataFrame.from_dict(seqs, orient="index", columns=["sequence"])
    .rename_axis("id")
    .reset_index()
)

In [5]:
# Preview the first rows of the id/sequence table.
df.head(n=5)

Unnamed: 0,id,sequence
0,Alb3-A3CT,MDENASKIISAGRAKRSIAQPDDAGERFRQLKEQEKRSKKNKAVAK...
1,FCP1,PGPEEQEEEPQPRKPGTRRERTLGAPASSERSAAGGRGPRGHKRKL...
2,emerin_67-170,GTRGDADMYDLPKKEDALLYQSKGYNDDYYEESYFTTRTYGEPESA...
3,UBact,MIQSLMPERRERPGDPMPKSPSPLEEGGGPRRPETGSPDKDSLLKR...
4,Nsp2_CtlIDR,KEIIFLEGETLPTEVLTEEVVLKTGDLQPLEQPTSEAVEAPLVGT


## 2) Load Model
Next, we load the configuration and checkpoint of the trained model, and wrap the sequence DataFrame in a `FrustrationDataModule` for prediction.

In [6]:
# Load the training configuration saved with the experiment. It provides
# `experiment_name` (used below to locate the checkpoint) and is passed to the
# model when the checkpoint is loaded.
config_path = "../data/it5_ABL_protT5/config.yaml"  # plain string: no f-string placeholders needed
with open(config_path, "r") as f:
    config = yaml.safe_load(f)
config["experiment_name"]

'it5_ABL_protT5'

In [None]:
# Backbone protein language model: either a local path to a pretrained model or
# a Hugging Face model name (e.g. "Rostlab/prot_t5_xl_uniref50" for ProtT5-XL).
config.update({"pLM_model": "Rostlab/prot_t5_xl_uniref50"})

In [8]:
# Restore the trained weights from the experiment's best validation checkpoint.
# Extra keyword arguments (`config`) are forwarded by Lightning to the model's
# constructor, so the pLM backbone chosen above is used.
model = FrustraSeq.load_from_checkpoint(
    checkpoint_path="../data/{}/best_val_model.ckpt".format(config["experiment_name"]),
    config=config,
)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


RANK -1: Model initialized.


In [9]:
# Wrap the id/sequence table in the project's Lightning data module for prediction.
# max_seq_length is set to the longest input sequence (presumably so that no
# sequence is truncated; confirm in FrustrationDataModule).
predict_dataloader = FrustrationDataModule(
    df=df,
    max_seq_length=df.sequence.str.len().max(),
    batch_size=5,
    num_workers=1,
    persistent_workers=True,
)

In [10]:
# Attach the per-amino-acid statistics, precomputed on the training set, that the
# model uses to compute the surprisal feature at inference time.
with open('../data/frustration/reg_heuristic.json') as heuristic_file:
    model.surprisal_dict = json.load(heuristic_file)
# Inspect the entry for alanine.
model.surprisal_dict["A"]

{'mean': 0.24633155516241328, 'std': 0.5921655729624687}

## 3) Inference

Let's define a Lightning trainer and run inference (prediction).

In [11]:
# "auto" lets Lightning pick the available hardware (Apple MPS, CUDA GPU, or CPU),
# so the notebook runs unchanged on any machine. Pass accelerator="mps", "gpu"
# or "cpu" explicitly to force a specific device.
trainer = Trainer(accelerator="auto")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
# Run inference. The per-protein results are read from `model.pred_list` below;
# trainer.predict itself returns one None per batch, so its output is suppressed.
# NOTE(review): if the model does not reset pred_list when prediction starts,
# re-running this cell would append duplicate entries; confirm in FrustraSeq.
trainer.predict(model, datamodule=predict_dataloader);

Loaded 16 sequences for prediction.
Created test dataset for prediction
Test dataset size: 16 samples


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Predicting DataLoader 0:  25%|██▌       | 1/4 [00:03<00:10,  0.29it/s]



Predicting DataLoader 0: 100%|██████████| 4/4 [00:10<00:00,  0.37it/s]


[None, None, None, None]

In [13]:
# Sanity check: how many per-protein prediction dictionaries were produced, and
# which per-residue fields each one carries.
(len(model.pred_list), model.pred_list[0].keys())

(16,
 dict_keys(['residue', 'regression', 'classification', 'entropy', 'surprisal']))

We can then create either per-protein DataFrames or one combined DataFrame for all sequences in the input FASTA file.

In [14]:
# Per-residue predictions for the first entry of model.pred_list
# (presumably the first FASTA record), one row per residue.
per_protein_df = pd.DataFrame(data=model.pred_list[0])
per_protein_df.head(n=5)

Unnamed: 0,residue,regression,classification,entropy,surprisal
0,M,0.037727,1,0.978943,-1.093094
1,D,0.24512,1,0.908195,0.82092
2,E,0.383294,2,0.81125,1.254376
3,N,0.27951,1,0.822857,1.229137
4,A,0.415725,2,0.727638,0.286058


In [15]:
# Build one long table (one row per residue) covering every protein, tagging each
# row with its protein id. The prediction dicts are not mutated, so this cell and
# the per-protein cell above stay re-runnable in any order.
# Pairing relies on model.pred_list being in the same order as `df`; TODO confirm
# the data module does not shuffle during prediction.
assert len(model.pred_list) == len(df), (
    f"expected {len(df)} predictions, got {len(model.pred_list)}"
)
combined_df = pd.concat(
    [
        pd.DataFrame(pred).assign(id=protein_id)
        for pred, protein_id in zip(model.pred_list, df["id"])
    ],
    ignore_index=True,
)
combined_df

Unnamed: 0,residue,regression,classification,entropy,surprisal,id
0,M,0.037727,1,0.978943,-1.093094,Alb3-A3CT
1,D,0.245120,1,0.908195,0.820920,Alb3-A3CT
2,E,0.383294,2,0.811250,1.254376,Alb3-A3CT
3,N,0.279510,1,0.822857,1.229137,Alb3-A3CT
4,A,0.415725,2,0.727638,0.286058,Alb3-A3CT
...,...,...,...,...,...,...
1332,V,0.972561,2,0.265225,-0.683747,His-PknG_1-75
1333,R,-0.217082,1,0.572920,-0.200045,His-PknG_1-75
1334,R,-0.014122,1,0.614342,0.044282,His-PknG_1-75
1335,L,0.976999,2,0.283272,-0.422705,His-PknG_1-75


In [22]:
# Save the predictions to CSV if needed. To also save the single-protein table, call
# per_protein_df.to_csv("./bonomi_protein1_predictions.csv", index=False).
combined_df.to_csv(path_or_buf="./bonomi_all_proteins_predictions.csv", index=False)

Use the entropy and surprisal scores to filter the predictions. The example below keeps residues where the model is confident (entropy <= 0.4) but the predicted value is unlikely given that amino acid's frustration distribution (surprisal <= -1 OR surprisal >= 1). The surprisal score is a z-score, $(\text{regression} - \mu_{AA}) / \sigma_{AA}$, computed from the per-amino-acid mean and standard deviation in `model.surprisal_dict`. A surprisal of 1 therefore means the predicted regression value is one standard deviation away from its amino acid's mean. Feel free to play around with both thresholds :)

In [20]:
# Thresholds for "confident but surprising" residues; tune freely.
# Low entropy marks a confident prediction. |surprisal| >= SURPRISAL_MIN_ABS means
# the predicted regression value lies at least that many standard deviations from
# its amino acid's training-set mean.
ENTROPY_MAX = 0.4
SURPRISAL_MIN_ABS = 1.0
combined_df.loc[
    (combined_df["entropy"] <= ENTROPY_MAX)
    & (
        (combined_df["surprisal"] <= -SURPRISAL_MIN_ABS)
        | (combined_df["surprisal"] >= SURPRISAL_MIN_ABS)
    )
]

Unnamed: 0,residue,regression,classification,entropy,surprisal,id
28,R,0.956004,2,0.393493,1.212142,Alb3-A3CT
292,I,0.946722,2,0.23075,-1.09596,UBact
350,Q,-0.172186,1,0.376972,1.037178,UBact
357,I,0.947645,2,0.321762,-1.093352,Nsp2_CtlIDR
591,I,0.937512,2,0.37843,-1.121969,p61_Hck
964,G,0.148351,1,0.394822,1.151247,Colicin_N_T_domain
966,S,0.116753,1,0.362394,1.066665,Colicin_N_T_domain


In [None]:
# Show the parameters used for the surprisal score of one amino acid: the mean and
# std of its training-set regression values. Set `amino_acid` to any one-letter
# code; lowercase input is accepted.
amino_acid = "R"
model.surprisal_dict[amino_acid.upper()]

{'mean': -0.050906470495787476, 'std': 0.8306870505452515}