In [1]:
import sys
import yaml
import json
import tqdm
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import pyarrow.parquet as pq
from pytorch_lightning import Trainer

from matplotlib import pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel

# Extend the import search path with the parent directory and the local
# 'FrustraSeq' folder so the project package below can be imported.
# NOTE(review): both entries are relative and therefore resolve against the
# kernel's current working directory — confirm where the notebook is launched from.
sys.path.append('..')
sys.path.append('FrustraSeq')
from FrustraSeq.models.FrustraSeq import FrustraSeq
from FrustraSeq.dataloader import FrustrationDataModule

# Allow PyTorch to use reduced-precision kernels for float32 matrix
# multiplications (speed over exactness); see the PyTorch docs for details.
torch.set_float32_matmul_precision('medium')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# FASTA file that the following cells parse into sequences.
fasta_file_path = (
    "../data/frustration/"
    "uniprotkb_human_AND_model_organism_9606_2025_12_31.fasta"
)

In [3]:
# Parse the FASTA file into `seqs`, a dict mapping record identifier -> residue string.
#
# * The identifier is the second '|'-separated field of the header line
#   (e.g. ">sp|P12345|NAME ..." -> "P12345"); headers without '|' fall back to
#   the first whitespace-delimited token instead of raising IndexError.
# * The file is streamed line by line and each record is joined once, rather
#   than reading the whole file and growing strings with repeated `+=`.
# * `header` is reset on every run, so sequence data appearing before the first
#   '>' raises a clear error instead of being appended to a stale identifier
#   left over from a previous execution of this cell.
# * If an identifier occurs more than once, the last record wins.
seqs = {}
header = None
chunks = []
with open(fasta_file_path, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # blank lines carry no residues
            continue
        if line.startswith(">"):
            # flush the record collected so far before starting a new one
            if header is not None:
                seqs[header] = "".join(chunks)
            fields = line[1:].split('|')
            if len(fields) > 1:
                header = fields[1]
            else:
                tokens = line[1:].split()
                if not tokens:
                    raise ValueError(f"Empty FASTA header in {fasta_file_path!r}")
                header = tokens[0]
            chunks = []
        elif header is None:
            raise ValueError(
                f"Sequence data found before the first FASTA header in {fasta_file_path!r}"
            )
        else:
            chunks.append(line)

# flush the final record
if header is not None:
    seqs[header] = "".join(chunks)

In [4]:
# Re-order the dict from shortest to longest sequence. The sort is stable, so
# entries of equal length keep their original relative order.
ids_by_length = sorted(seqs, key=lambda seq_id: len(seqs[seq_id]))
seqs = {seq_id: seqs[seq_id] for seq_id in ids_by_length}

In [6]:
# Tabulate the parsed records: one row per entry of `seqs` (in dict order),
# with the identifier in "id" and the residue string in "sequence".
df = pd.DataFrame({
    "id": list(seqs.keys()),
    "sequence": list(seqs.values()),
})
df

Unnamed: 0,id,sequence
0,P0DPR3,EI
1,P0DPI4,GTGG
2,P01858,TKPR
3,P0DOY5,GTTGT
4,P02729,CEHSHDGA
...,...,...
20415,P20929,MADDEDYEEVVEYYTEEVVYEEVPGETITKIYETTTTRTSDYEQSE...
20416,Q8NF91,MATSRGASRCPRDIANVMQRLQDEQEIVQKRTFTKWINSHLAKRKP...
20417,Q9H195,MQLLGLLSILWMLKSSPGATGTLSTATSTSHVTFPRAEATRTALSN...
20418,Q8WXI7,MLKPSGLPGSSSPTRSLMTGSRSTKATPEMDSGLTGATLSPKTSTG...


In [7]:
# Number of residues per sequence, stored alongside the sequence itself.
df["length"] = df["sequence"].str.len()
df

Unnamed: 0,id,sequence,length
0,P0DPR3,EI,2
1,P0DPI4,GTGG,4
2,P01858,TKPR,4
3,P0DOY5,GTTGT,5
4,P02729,CEHSHDGA,8
...,...,...,...
20415,P20929,MADDEDYEEVVEYYTEEVVYEEVPGETITKIYETTTTRTSDYEQSE...,8525
20416,Q8NF91,MATSRGASRCPRDIANVMQRLQDEQEIVQKRTFTKWINSHLAKRKP...,8797
20417,Q9H195,MQLLGLLSILWMLKSSPGATGTLSTATSTSHVTFPRAEATRTALSN...,13477
20418,Q8WXI7,MLKPSGLPGSSSPTRSLMTGSRSTKATPEMDSGLTGATLSPKTSTG...,14507


In [10]:
# Inspect the entries whose sequence is longer than 4096 residues.
df[df["length"] > 4096]

Unnamed: 0,id,sequence,length
20343,Q8TD57,MGATGRLELTLAAPPHPGPAFQRSKARETQGEEEGSEMQIAKSDSI...,4116
20344,P78527,MAGSGAGVRCSLLRLQETLSAADRCGAALAGHQLIRGLGQECVLSS...,4128
20345,Q9C0G6,MTFRATDSEFDLTNIEEYAENSALSRLNNIKAKQRVSYVTSTENES...,4158
20346,Q8TCU4,MEPEDLPWPGELEEEEEEEEEEEEEEEEAAAAAAANVDDVVVVEEV...,4168
20347,Q86WI1,MGHLWLLGIWGLCGLLLCAADPSTDGSQIIPKVTEIIPKYGSINGA...,4243
...,...,...,...
20415,P20929,MADDEDYEEVVEYYTEEVVYEEVPGETITKIYETTTTRTSDYEQSE...,8525
20416,Q8NF91,MATSRGASRCPRDIANVMQRLQDEQEIVQKRTFTKWINSHLAKRKP...,8797
20417,Q9H195,MQLLGLLSILWMLKSSPGATGTLSTATSTSHVTFPRAEATRTALSN...,13477
20418,Q8WXI7,MLKPSGLPGSSSPTRSLMTGSRSTKATPEMDSGLTGATLSPKTSTG...,14507


In [6]:
# Map the non-standard / ambiguous residue codes U, O, B and Z to the
# "unknown" code X in every sequence.
# NOTE(review): presumably required by the protein language model's
# vocabulary — confirm against the tokenizer.
df["sequence"] = df["sequence"].str.replace(r"[UOBZ]", "X", regex=True)
df

Unnamed: 0,id,sequence
0,P0DPR3,EI
1,P0DPI4,GTGG
2,P01858,TKPR
3,P0DOY5,GTTGT
4,P02729,CEHSHDGA
...,...,...
20415,P20929,MADDEDYEEVVEYYTEEVVYEEVPGETITKIYETTTTRTSDYEQSE...
20416,Q8NF91,MATSRGASRCPRDIANVMQRLQDEQEIVQKRTFTKWINSHLAKRKP...
20417,Q9H195,MQLLGLLSILWMLKSSPGATGTLSTATSTSHVTFPRAEATRTALSN...
20418,Q8WXI7,MLKPSGLPGSSSPTRSLMTGSRSTKATPEMDSGLTGATLSPKTSTG...


In [7]:
# Draw a reproducible random subset for prediction (fixed seed).
# The sample size is clamped to the number of available rows so the cell no
# longer raises ValueError on inputs with fewer than 1000 sequences; for larger
# frames the selected rows are unchanged.
df = df.sample(n=min(1000, len(df)), random_state=42)

In [8]:
# Read the experiment's YAML configuration into `config` and show its
# experiment name as a sanity check.
config_path = "../data/it5_ABL_protT5_CW_LORA/config.yaml"
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)
config["experiment_name"]

'it5_ABL_protT5_CW_LORA'

In [9]:
# Where the pretrained protein language model is taken from: either a local
# path to the pretrained model, or a Hugging Face model name
# (e.g. "Rostlab/prot_t5_xl_uniref50" for protT5-xl).
plm_model_source = "../data/protT5"
config["pLM_model"] = plm_model_source

In [10]:
# Use the length of the longest parsed sequence as the model's maximum
# sequence length. Computed with max() over all sequences so the value is
# correct even if the "sort by length" cell was skipped or run out of order
# (taking the last dict entry is only the maximum when `seqs` is sorted).
config["max_seq_length"] = max(len(sequence) for sequence in seqs.values())
config["max_seq_length"]

34350

In [11]:
# Restore the model from the experiment's best-validation checkpoint, passing
# the (locally adjusted) config to the constructor.
best_checkpoint = f"../data/{config['experiment_name']}/best_val_model.ckpt"
model = FrustraSeq.load_from_checkpoint(
    checkpoint_path=best_checkpoint,
    config=config,
)

c:\Users\Jan\.conda\envs\mvtcr_plus\Lib\site-packages\pytorch_lightning\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.5.5, which is newer than your current Lightning version: v2.5.1


Using LoRA fine-tuning for ['q', 'k', 'v', 'o'] layers
trainable params: 1,967,104 || all params: 1,210,107,904 || trainable%: 0.1626
Using class-weighted cross-entropy loss.
RANK -1: Model initialized.


In [12]:
# Attach the surprisal dictionary to the model. It holds precomputed values
# per amino acid (from the train set) that are used to compute the surprisal
# feature during inference. The entry for "A" is displayed as a quick check.
with open('../data/frustration/reg_heuristic.json', 'r') as heuristic_file:
    surprisal_stats = json.load(heuristic_file)
model.surprisal_dict = surprisal_stats
model.surprisal_dict["A"]

{'mean': 0.24633155516241328, 'std': 0.5921655729624687}

In [13]:
# Lightning trainer used for inference only, with bf16 mixed precision.
# Accelerator choice: 'gpu' on CUDA-enabled devices; use 'mps' or 'cpu'
# (cpu only) on machines without CUDA.
trainer = Trainer(
    accelerator='gpu',
    precision="bf16-mixed",
)

Using bfloat16 Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Run prediction over the whole sampled frame, one sequence per batch.
# Note: despite its name, `predict_dataloader` is the project's
# FrustrationDataModule, which is handed to the trainer directly.
longest_sequence = df["sequence"].str.len().max()
predict_dataloader = FrustrationDataModule(df=df,
                                           max_seq_length=longest_sequence,
                                           batch_size=1,
                                           num_workers=1,
                                           persistent_workers=True,)

# Bind the result to a name so it can be used by later cells; previously it
# was only echoed as the cell's output (and reachable solely through IPython's
# output cache).
predictions = trainer.predict(model, predict_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loaded 1000 sequences for prediction.
Created test dataset for prediction
Test dataset size: 1000 samples


c:\Users\Jan\.conda\envs\mvtcr_plus\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0:   0%|          | 0/1000 [00:00<?, ?it/s]Predicting batch with max sequence length: 872
Predicting DataLoader 0:   0%|          | 1/1000 [00:00<07:45,  2.15it/s]Predicting batch with max sequence length: 513
Predicting DataLoader 0:   0%|          | 2/1000 [00:00<04:34,  3.63it/s]Predicting batch with max sequence length: 514
Predicting DataLoader 0:   0%|          | 3/1000 [00:00<03:28,  4.77it/s]



Predicting batch with max sequence length: 259
Predicting DataLoader 0:   0%|          | 4/1000 [00:00<02:48,  5.90it/s]Predicting batch with max sequence length: 654
Predicting DataLoader 0:   0%|          | 5/1000 [00:00<02:34,  6.42it/s]Predicting batch with max sequence length: 311
Predicting DataLoader 0:   1%|          | 6/1000 [00:00<02:17,  7.21it/s]Predicting batch with max sequence length: 794
Predicting DataLoader 0:   1%|          | 7/1000 [00:00<02:16,  7.30it/s]Predicting batch with max sequence length: 272
Predicting DataLoader 0:   1%|          | 8/1000 [00:01<02:04,  7.94it/s]Predicting batch with max sequence length: 378
Predicting DataLoader 0:   1%|          | 9/1000 [00:01<01:57,  8.43it/s]Predicting batch with max sequence length: 110
Predicting DataLoader 0:   1%|          | 10/1000 [00:01<01:49,  9.03it/s]Predicting batch with max sequence length: 246
Predicting DataLoader 0:   1%|          | 11/1000 [00:01<01:43,  9.56it/s]Predicting batch with max sequence len

In [None]:
# Predict in chunks of `bs` rows. Each chunk gets its own data module whose
# maximum sequence length is the longest sequence in that chunk, and the
# model's `max_seq_length` is set to the same value.
bs = 128

batch_predictions = []  # one entry per successful chunk: the value returned by trainer.predict
failed_batches = []     # (start index, error text) for every chunk that raised

for i in tqdm.tqdm(range(0, len(df), bs)):
    batch_df = df.iloc[i:i+bs]
    batch_max_len = batch_df["sequence"].str.len().max()  # computed once, used twice below
    predict_dataloader = FrustrationDataModule(df=batch_df,
                                               max_seq_length=batch_max_len,
                                               batch_size=bs,
                                               num_workers=1,
                                               persistent_workers=True,)
    model.max_seq_length = batch_max_len
    try:
        batch_predictions.append(trainer.predict(model, predict_dataloader))
    except Exception as e:
        # Best effort: keep going with the remaining chunks, but remember which
        # chunk failed. Only the error text is stored (not the exception
        # object) so that its traceback does not keep tensors alive.
        failed_batches.append((i, repr(e)))
        print(f"Model max_seq_length: {model.max_seq_length}")
        print(f"Error at batch starting with index {i}: {e}")