In [6]:
%load_ext autoreload
%autoreload 2

import os
import sys
import yaml
import json
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer
from Bio import SeqIO

sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'src'))

from frustraiseq.data.dataloader import FunstrationDataModule
from frustraiseq.model.frustraiseq import FrustrAISeq

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


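# FrustrAI-Seq: per-residue frustration prediction from sequence

This notebook runs FrustrAI-Seq inference on the protein sequences in a FASTA file. It loads the data, the ProtT5 protein language model and the FrustrAI-Seq checkpoint. It then predicts per-residue frustration (index, class, entropy and surprisal) and shows how to filter the predictions with the entropy and surprisal scores.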

## 1) Load Data
First, we load a simple FASTA file containing the sequence headers (IDs) and their sequences, and store them in a `pd.DataFrame`. Make sure the columns are named `id` and `sequence`.

In [None]:
# Path to the input FASTA file, relative to the kernel's working directory
# (normally the notebook's folder). Point it at your own file to score other sequences.
fasta_file_path = "../data/example_seqs.fasta"

In [10]:
# Read every record of the FASTA file with Biopython's SeqIO parser.
with open(fasta_file_path, 'r') as handle:
    fasta_records = list(SeqIO.parse(handle, "fasta"))

# Keep the header ID and the sequence string of each record, in file order.
ids = [rec.id for rec in fasta_records]
sequences = [str(rec.seq) for rec in fasta_records]

df = pd.DataFrame({'id': ids, 'sequence': sequences})
df

Unnamed: 0,id,sequence
0,Seq1,SEQVENCE
1,Seq2,SEQVINCE


## 2) Load Model
Now let's load the config, the pLM (protein language model) and the model checkpoint. If ProtT5 and FrustrAI-Seq are not already available locally, download them and specify their locations in the config.

In [14]:
# Load the YAML config for the model to be used (here: the packaged default config).
# safe_load only builds plain Python types (dicts, lists, scalars) from the file.
# The encoding is explicit so that decoding does not depend on the OS locale.
with open("../src/frustraiseq/config/default_config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
config["experiment_name"]

'FrustrAI-Seq_Prediction'

In [None]:
# Optional: fetch the ProtT5 encoder and tokenizer from the Hugging Face Hub, save them
# locally and point the config at the saved copy.
# Disabled by default so that "Run All" does not start a large download.
# Set DOWNLOAD_PLM = True to run this step.
DOWNLOAD_PLM = False

if DOWNLOAD_PLM:
    from transformers import T5EncoderModel, T5Tokenizer

    plm_hub_id = "Rostlab/prot_t5_xl_half_uniref50-enc"
    plm_local_dir = "./prot_t5_xl_half_uniref50-enc"

    encoder = T5EncoderModel.from_pretrained(plm_hub_id)
    tokenizer = T5Tokenizer.from_pretrained(plm_hub_id)
    encoder.save_pretrained(plm_local_dir)
    tokenizer.save_pretrained(plm_local_dir)
    # Only the files on disk are needed from here on; release the in-memory copies.
    del encoder
    del tokenizer

    config["pLM_model"] = plm_local_dir

In [None]:
# Optional: fetch the FrustrAI-Seq checkpoint from the Hugging Face Hub and point the
# config at the downloaded file.
# Disabled by default so that "Run All" does not trigger a download.
# Set DOWNLOAD_CHECKPOINT = True to run this step.
DOWNLOAD_CHECKPOINT = False

if DOWNLOAD_CHECKPOINT:
    from huggingface_hub import hf_hub_download

    # hf_hub_download returns the local filesystem path of the downloaded file.
    local_path = hf_hub_download(
        repo_id="leusch/FrustrAI-Seq",
        filename="FrustrAISeq_CW.ckpt",
        local_dir="./FrustrAI-Seq",
    )
    config["checkpoint_path"] = local_path

In [None]:
# Alternatively, point the config at local copies that are already on disk.
# NOTE: this unconditionally overwrites config["pLM_model"] and config["checkpoint_path"].
# Adjust the paths to wherever your ProtT5 encoder and FrustrAI-Seq checkpoint live.
config.update(
    pLM_model="../data/protT5",
    checkpoint_path="../data/it5_ABL_protT5_CW_LORA/best_val_model.ckpt",
)

In [None]:
# Restore the trained FrustrAI-Seq weights from the checkpoint; the loaded config
# is passed through to the model as well.
model = FrustrAISeq.load_from_checkpoint(
    checkpoint_path=config["checkpoint_path"],
    config=config,
)

RANK -1: Loaded pLM model ../data/protT5
RANK -1: Using LoRA fine-tuning for ['q', 'k', 'v', 'o'] layers
trainable params: 1,967,104 || all params: 1,210,107,904 || trainable%: 0.1626
RANK -1: Using half precision.
RANK -1: Applying class weights for CrossEntropyLoss: [2.65750085, 0.68876299, 0.8533673]
RANK -1: Loaded surprisal dictionary.
RANK -1: Model initialized.


In [22]:
# Wrap the parsed sequences (`df`, columns `id` and `sequence`) in the project's data
# module for inference. Batch size and worker count come from the loaded config.
# persistent_workers stays off because only a single prediction pass is run.
predict_data_module = FunstrationDataModule(
    config=config,
    inference_dataset=df,
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    persistent_workers=False,
    pin_memory=True,
)

## 3) Inference

Let's define a Lightning `Trainer` and run inference (prediction).

In [26]:
# Prediction-only trainer: Lightning picks the available accelerator (e.g. CUDA, MPS or CPU).
# Experiment logging is switched off since nothing is trained or tracked here.
# Adjust the trainer arguments as needed.
trainer = Trainer(logger=False, accelerator="auto")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [29]:
# Run inference; the data module supplies the prediction dataloader.
predictions = trainer.predict(model=model, datamodule=predict_data_module)

Created test dataset for prediction
Test dataset size: 2 samples


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:428: Consider setting `persistent_workers=True` in 'predict_dataloader' to speed up the dataloader worker initialization.


Predicting: |          | 0/? [00:00<?, ?it/s]
During prediction sequence length limit will always be the longest sequence in the batch, so consider using batch size of 1 for inference to minimize memory usage.

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 14.19it/s]


In [30]:
# trainer.predict returns one entry per batch. Drop batches that produced nothing (None)
# and flatten the rest, so that each element becomes one DataFrame row.
results = [
    seq_pred
    for pred_batch in predictions
    if pred_batch is not None
    for seq_pred in pred_batch
]

df_output = pd.DataFrame(results)
df_output

Unnamed: 0,id,residue,frustration_index,frustration_class,entropy,surprisal
0,Seq1,"[S, E, Q, V, E, N, C, E]","[0.13900521, -0.2924483, -0.14515015, 0.571259...","[1, 0, 1, 2, 0, 0, 2, 1]","[0.79690963, 0.8030409, 0.83229244, 0.73676383...","[1.1143804, 0.43712324, 1.0951446, -1.876019, ..."
1,Seq2,"[S, E, Q, V, I, N, C, E]","[0.06939763, -0.2897056, -0.41023245, 0.557763...","[1, 0, 1, 2, 2, 1, 2, 1]","[0.7573504, 0.81135726, 0.74281573, 0.7773398,...","[0.9651208, 0.44044033, 0.5267884, -1.9161178,..."


In [None]:
# Save the per-sequence predictions. The per-residue columns hold sequences of values,
# which to_csv writes using each cell's string representation.
# If those cells are NumPy arrays, NumPy abbreviates arrays longer than its print
# threshold (1000 elements by default) with "...", which would silently truncate long
# proteins in the CSV. Lifting the threshold while writing keeps every value.
# Plain Python lists are unaffected.
with np.printoptions(threshold=sys.maxsize):
    df_output.to_csv("predictions.csv", index=False)

Use the entropy and surprisal scores to filter the predictions. The example below keeps residues for which the model is confident, i.e. has low entropy (entropy < 0.5). It further restricts them to residues whose predicted value is unlikely given that amino acid's frustration distribution (surprisal score < -1 OR surprisal score > 1). A surprisal score of 1 means that the predicted regression value is one standard deviation away from its amino acid's mean. Feel free to play around with both scores :)

In [39]:
# One row per residue: explode the per-sequence list columns in lockstep.
# Multi-column explode requires every listed column to have the same length within a row.
res_df = (
    df_output
    .explode(["residue", "frustration_index", "frustration_class", "entropy", "surprisal"])
    .reset_index(drop=True)
)

# explode() leaves the exploded columns with object dtype. Convert the numeric ones so
# that filtering, sorting and aggregation (e.g. describe()) treat them as numbers.
numeric_cols = ["frustration_index", "frustration_class", "entropy", "surprisal"]
res_df[numeric_cols] = res_df[numeric_cols].apply(pd.to_numeric)
res_df

Unnamed: 0,id,residue,frustration_index,frustration_class,entropy,surprisal
0,Seq1,S,0.139005,1,0.79691,1.11438
1,Seq1,E,-0.292448,0,0.803041,0.437123
2,Seq1,Q,-0.14515,1,0.832292,1.095145
3,Seq1,V,0.57126,2,0.736764,-1.876019
4,Seq1,E,-0.447937,0,0.728689,0.249072
5,Seq1,N,-0.351942,0,0.758936,0.267395
6,Seq1,C,0.785098,2,0.532843,-1.126765
7,Seq1,E,0.171754,1,0.97654,0.998537
8,Seq2,S,0.069398,1,0.75735,0.965121
9,Seq2,E,-0.289706,0,0.811357,0.44044


In [42]:
# Keep confident (low-entropy) residues whose prediction is atypical for their amino acid,
# i.e. whose surprisal score lies more than SURPRISAL_ABS_MIN away from zero.
ENTROPY_MAX = 0.5
SURPRISAL_ABS_MIN = 1

is_confident = res_df["entropy"] < ENTROPY_MAX
is_surprising = (res_df["surprisal"] > SURPRISAL_ABS_MIN) | (res_df["surprisal"] < -SURPRISAL_ABS_MIN)
res_df.loc[is_confident & is_surprising]

Unnamed: 0,id,residue,frustration_index,frustration_class,entropy,surprisal
12,Seq2,I,0.769839,2,0.461895,-1.595465
