In [7]:
%load_ext autoreload
%autoreload 2

import sys
import tqdm
import torch
import peft
import numpy as np
import pandas as pd
import scanpy as sc
import torch.nn as nn
import pytorch_lightning as pl
import pyarrow.parquet as pq
import seaborn as sns
import matplotlib.pyplot as plt

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from transformers import T5Tokenizer, T5EncoderModel
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from peft import get_peft_config, PeftModel, PeftConfig, LoraConfig, get_peft_model

sys.path.append('..')
from pLMtrainer.dataloader import FrustrationDataModule
from pLMtrainer.models import FrustrationCNN

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### LoRA Debug

In [2]:
exp_name = "it3_lora_debug"
parquet_path = "../data/frustration/v4_frustration.parquet.gzip"
max_seq_length = 50
batch_size = 10
num_workers = 1
cath_sampling_n = 1  # None for no sampling

In [3]:
device = "mps"
#pLM_model = "../data/prostT5"
pLM_model = "../data/protT5"
prefix_prostT5 = "<AA2fold>"
tokenizer = T5Tokenizer.from_pretrained(pLM_model, do_lower_case=False, max_length=max_seq_length)
encoder = T5EncoderModel.from_pretrained(pLM_model).to(device)

In [4]:
data_module = FrustrationDataModule(df=None,
                                    parquet_path=parquet_path, 
                                    batch_size=batch_size, 
                                    max_seq_length=max_seq_length, 
                                    num_workers=num_workers, 
                                    persistent_workers=True,
                                    sample_size=None,
                                    cath_sampling_n=cath_sampling_n)

In [5]:
class FrustrationCNN(pl.LightningModule):
    def __init__(self, 
                 input_dim=1024, 
                 hidden_dims=[64, 10], 
                 output_dim=4, # 3 classes + 1 for regression 
                 dropout=0.15,
                 max_seq_length=512,
                 precision="full",
                 use_loraFT=None, # either "QK" or "all"
                 pLM_model="Rostlab/ProtT5",
                 prefix_prostT5="<AA2fold>",
                 no_label_token=-100,):
        super(FrustrationCNN, self).__init__()

        self.tokenizer = T5Tokenizer.from_pretrained(pLM_model, do_lower_case=False, max_length=max_seq_length)
        self.encoder = T5EncoderModel.from_pretrained(pLM_model).to(self.device)
        self.plm_model = pLM_model
        self.prefix_prostT5 = prefix_prostT5
        self.max_seq_length = max_seq_length  # seq len will be max_seq_length-1 later since we prepend the prefix
        self.precision = precision
        self.use_loraFT = use_loraFT

        if self.precision == "half":
            self.encoder.half()
            #self.CNN.half()
            print("Using half precision")

        if self.use_loraFT == "all":
            print("Using LoRA fine-tuning for all layers")
            peft_config = LoraConfig(
                task_type="FEATURE_EXTRACTION",
                inference_mode=False,
                r=8,
                lora_alpha=32,
                target_modules=["q", "k", "v", "o", "wi", "wo", "w1", "w2", "w3", "fc1", "fc2", "fc3"],
                #lora_dropout=0.1,
            )
            self.encoder.train()
            self.encoder = get_peft_model(self.encoder, peft_config)
            self.encoder.print_trainable_parameters()
        elif self.use_loraFT == "QK":
            print("Using LoRA fine-tuning for Q and K layers only")
            peft_config = LoraConfig(
                task_type="FEATURE_EXTRACTION",
                inference_mode=False,
                r=8,
                lora_alpha=32,
                target_modules=["q", "k"],
                #lora_dropout=0.1,
            )
            self.encoder.train()
            self.encoder = get_peft_model(self.encoder, peft_config)
            self.encoder.print_trainable_parameters()
        else:
            self.encoder.eval()  # Freeze the encoder   

        # https://github.com/RSchmirler/ProtT5-EvoTuning/blob/main/notebook/PT5_EvoTuning.ipynb 
        # lora modification
        #peft_config = LoraConfig(
        #    r=4, lora_alpha=1, bias="all", target_modules=["q","k","v","o"], task_type = "SEQ_2_SEQ_LM",
        #)

        #https://github.com/mheinzinger/ProstT5/blob/main/scripts/predict_3Di_encoderOnly.py
        self.CNN = nn.Sequential(
            nn.Conv2d(input_dim, hidden_dims[0], kernel_size=(7, 1), padding=(3, 0)),  # 7x64
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=(7, 1), padding=(3, 0))
        )

        self.cls_head = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dims[1], output_dim-1),
        ) 

        self.reg_head = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dims[1], 1),
        )

        self.mse_loss_fn = nn.MSELoss()
        self.ce_loss_fn = nn.CrossEntropyLoss(ignore_index=no_label_token) # TODO look at weight param for class imbalance

    def forward(self, full_seq):
        #start_time = time.time()
        if "prostt5" in self.plm_model.lower():
            full_seq = [self.prefix_prostT5 + " " + " ".join(seq) for seq in full_seq]  # Add spaces between amino acids and prefix
        else:
            full_seq = [" ".join(seq) for seq in full_seq]

        ids = self.tokenizer.batch_encode_plus(full_seq, 
                                               add_special_tokens=True, 
                                               max_length=self.max_seq_length,
                                               padding="max_length",
                                               truncation="longest_first",
                                               return_tensors='pt'
                                               ).to(self.device)
        
        if self.use_loraFT == "all" or self.use_loraFT == "QK":
            embedding_rpr = self.encoder(
                input_ids=ids.input_ids, 
                attention_mask=ids.attention_mask
            )
        else:
            with torch.no_grad():
                embedding_rpr = self.encoder(
                    input_ids=ids.input_ids, 
                    attention_mask=ids.attention_mask
                )
                
        if "prostt5" in self.plm_model.lower():
            embeddings = embedding_rpr.last_hidden_state[:, 1:].float() # remove the aa token bos and eos and bring to shape
        else:
            embeddings = embedding_rpr.last_hidden_state.float() # remove the aa token bos and bring to shape

        embeddings = embeddings.permute(0, 2, 1).unsqueeze(-1)  # (batch_size, input_dim, seq_length, 1)

        res = self.CNN(embeddings).squeeze(-1).permute(0, 2, 1)  # (batch_size, seq_length, output_dim)
        cls_res = self.cls_head(res)
        reg_res = self.reg_head(res)
        #end_time = time.time()
        #print(f"Forward pass time: {end_time - start_time} seconds")
        return cls_res, reg_res
    
    def general_step(self, batch, stage):
        full_seq, res_mask, frst_vals, frst_classes = batch
        if res_mask.sum() == 0:
            print(f"{stage.capitalize()} batch with no valid residues - skipping") 
            return None  # Skip this batch
        
        cls_preds, reg_preds = self.forward(full_seq)
        cls_preds = cls_preds.squeeze(-1)
        reg_preds = reg_preds.squeeze(-1)

        mse_loss = self.mse_loss_fn(reg_preds[res_mask], frst_vals[res_mask]) # shape (batch_size, 1)
        ce_loss = self.ce_loss_fn(cls_preds.flatten(0, 1), frst_classes.flatten()) # shape (batch_size, n_classes(3))
        loss = mse_loss + ce_loss

        self.log(f'{stage}_mse_loss', mse_loss, on_step=(stage=='train'), on_epoch=True, prog_bar=False)
        self.log(f'{stage}_ce_loss', ce_loss, on_step=(stage=='train'), on_epoch=True, prog_bar=False)
        self.log(f'{stage}_loss', loss, on_step=(stage=='train'), on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.general_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.general_step(batch, 'val')
        
    def test_step(self, batch, batch_idx):
        full_seq, res_mask, frst_vals, frst_classes = batch
        if res_mask.sum() == 0:
            print("Test batch with no valid residues - skipping") 
            return None  # Skip this batch
        
        cls_preds, reg_preds = self.forward(full_seq)
        cls_preds = cls_preds.squeeze(-1)
        reg_preds = reg_preds.squeeze(-1)

        mse_loss = self.mse_loss_fn(reg_preds[res_mask], frst_vals[res_mask]) # shape (batch_size, 1)
        ce_loss = self.ce_loss_fn(cls_preds.flatten(0,1), frst_classes.flatten()) # shape (batch_size, 20)
        loss = mse_loss + ce_loss

        self.log('test_mse_loss', mse_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log('test_ce_loss', ce_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        self.test_dict["full_seqs"].append(full_seq.cpu().numpy())
        self.test_dict["masks"].append(res_mask.cpu().numpy())

        self.test_dict["regr_preds"].append(reg_preds.cpu().numpy())
        self.test_dict["cls_preds"].append(cls_preds.cpu().numpy())
        self.test_dict["regr_targets"].append(frst_vals.cpu().numpy())
        self.test_dict["cls_targets"].append(frst_classes.cpu().numpy())

        self.test_dict["masked_regr_preds"].append(reg_preds[res_mask].flatten().cpu().numpy())
        self.test_dict["masked_cls_preds"].append(torch.argmax(cls_preds, dim=-1)[res_mask].flatten().cpu().numpy())
        self.test_dict["masked_regr_targets"].append(frst_vals[res_mask].flatten().cpu().numpy())
        self.test_dict["masked_cls_targets"].append(frst_classes[res_mask].flatten().cpu().numpy())

        return loss
    
    def predict_step(self, batch, batch_idx):
        full_seq, _, _, _ = batch
        preds = self.forward(full_seq)
        self.preds_list.append(preds.cpu().numpy())

    def configure_optimizers(self):
        if self.use_loraFT:
            # Include both CNN and LoRA adapter parameters
            params = list(self.CNN.parameters()) + list(self.encoder.parameters())
            optimizer = torch.optim.Adam(params, lr=1e-3)
        else:
            optimizer = torch.optim.Adam(self.CNN.parameters(), lr=1e-3)
        return optimizer

    def on_train_end(self):
        if self.use_loraFT is not None:
            # save the lora adapters
            lora_save_path = f"./{exp_name}/lora_adapters"
            print(f"Saving LoRA adapters to {lora_save_path}")
            self.encoder.save_pretrained(lora_save_path)

    def on_test_epoch_start(self):
        self.test_dict = {"full_seqs": [],
                          "masks": [],
                          "regr_preds": [], 
                          "cls_preds": [], 
                          "regr_targets": [], 
                          "cls_targets": [],
                          "masked_regr_preds": [], 
                          "masked_cls_preds": [], 
                          "masked_regr_targets": [], 
                          "masked_cls_targets": []}

    def on_test_epoch_end(self):
        #concate the batches
        self.test_dict["full_seqs"] = np.concatenate(self.test_dict["full_seqs"])
        self.test_dict["masks"] = np.concatenate(self.test_dict["masks"])
        self.test_dict["regr_preds"] = np.concatenate(self.test_dict["regr_preds"])
        self.test_dict["cls_preds"] = np.concatenate(self.test_dict["cls_preds"])
        self.test_dict["regr_targets"] = np.concatenate(self.test_dict["regr_targets"])
        self.test_dict["cls_targets"] = np.concatenate(self.test_dict["cls_targets"])
        self.test_dict["masked_regr_preds"] = np.concatenate(self.test_dict["masked_regr_preds"])
        self.test_dict["masked_cls_preds"] = np.concatenate(self.test_dict["masked_cls_preds"])
        self.test_dict["masked_regr_targets"] = np.concatenate(self.test_dict["masked_regr_targets"])
        self.test_dict["masked_cls_targets"] = np.concatenate(self.test_dict["masked_cls_targets"])

    def on_predict_start(self):
        self.preds_list = []
    
    def on_predict_end(self):
        self.preds_list = np.concatenate(self.preds_list)

    def save_preds_targets(self, path="./preds_targets.npz"):
        np.savez_compressed(path, **self.test_dict)

    @staticmethod
    def suggest_params():
        #TODO model selection
        pass



In [8]:
early_stop = EarlyStopping(monitor="val_loss",
                           patience=5,
                           mode='min',
                           verbose=True)
checkpoint = ModelCheckpoint(monitor="val_loss",
                             dirpath=f"./{exp_name}",
                             filename=f"train_best",
                             save_top_k=1,
                             mode='min',
                             save_weights_only=True)
logger = CSVLogger(f"./{exp_name}", name="train_logs")

In [9]:
model = FrustrationCNN(input_dim=1024, 
                       hidden_dims=[64, 10], 
                       output_dim=4, # 3 classes + 1 for regression 
                       dropout=0.15,
                       max_seq_length=max_seq_length,
                       precision="full", 
                       pLM_model="../data/protT5",
                       prefix_prostT5="<AA2fold>",
                       use_loraFT="all",
                       no_label_token=-100,)

Using LoRA fine-tuning for all layers
trainable params: 10,616,832 || all params: 1,218,758,656 || trainable%: 0.8711


In [10]:
model.encoder

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=4096, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                 

In [11]:
trainer = Trainer(accelerator='mps', # gpu
                  devices=-1, # 4 for one node on haicore
                  max_epochs=1,
                  logger=logger,
                  log_every_n_steps=5,
                  #val_check_interval=1, #3500 batches in total 
                  callbacks=[early_stop, checkpoint],
                  #precision="16-mixed",
                  gradient_clip_val=1,
                  enable_progress_bar=True,
                  deterministic=False, # for reproducibility disable on cluster 
                  #num_sanity_val_steps=0,
                  #accumulate_grad_batches=2, # if batch size gets too small --> test on H100/A100
                  )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
trainer.fit(model, datamodule=data_module)

Loaded 562 samples from ../data/frustration/v4_frustration.parquet.gzip
Created train/val/test masks
Initialized res_idx_mask and frst_vals tensors
Populated res_idx_mask and frst_vals tensors
Created train dataset
Created val dataset
Created test dataset
Train/Val/Test split: 482/30/50 samples


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/janleusch/Documents/phd/pLMtrainer/pLMtrainer/notebooks/it3_lora_debug exists and is not empty.

  | Name        | Type                          | Params | Mode 
----------------------------------------------------------------------
0 | encoder     | PeftModelForFeatureExtraction | 1.2 B  | train
1 | CNN         | Sequential                    | 463 K  | train
2 | cls_head    | Sequential                    | 33     | train
3 | reg_head    | Sequential                    | 11     | train
4 | mse_loss_fn | MSELoss                       | 0      | train
5 | ce_loss_fn  | CrossEntropyLoss              | 0      | train
----------------------------------------------------------------------
2.4 M     Trainable params
1.2 B     Non-trainable params
1.2 B     Total params
4,842.285 Total estimated model params size (MB)
936       Modules in tr

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 49/49 [01:37<00:00,  0.50it/s, v_num=9, train_loss_step=1.430, val_loss=1.240, train_loss_epoch=1.560]

Metric val_loss improved. New best score: 1.241
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 49/49 [01:44<00:00,  0.47it/s, v_num=9, train_loss_step=1.430, val_loss=1.240, train_loss_epoch=1.560]
Saving LoRA adapters to ./it3_lora_debug/lora_adapters


In [17]:
model.encoder.state_dict()

OrderedDict([('base_model.model.shared.weight',
              tensor([[-2.6719e+00, -7.8906e-01,  7.1094e-01,  ..., -1.0391e+00,
                        2.9883e-01,  3.3008e-01],
                      [-1.5500e+01,  6.5938e+00, -1.0300e-03,  ..., -6.0938e+00,
                       -4.3438e+00,  4.1250e+00],
                      [-1.4188e+01, -2.5312e+00,  8.5625e+00,  ..., -3.8906e+00,
                       -5.7188e+00,  1.4355e-01],
                      ...,
                      [-1.4312e+01, -1.1719e+00,  5.7500e+00,  ..., -4.3750e+00,
                       -6.2812e+00,  2.5156e+00],
                      [-1.4688e+01, -1.5000e+00,  7.2188e+00,  ..., -4.0938e+00,
                       -5.5000e+00,  1.3125e+00],
                      [-1.1719e+00, -3.4375e-01,  1.0781e+00,  ..., -3.4570e-01,
                        1.3184e-01,  4.0820e-01]])),
             ('base_model.model.encoder.embed_tokens.weight',
              tensor([[-2.6719e+00, -7.8906e-01,  7.1094e-01,  ..., -1.039

In [19]:
model.encoder.state_dict()["base_model.model.encoder.block.1.layer.0.SelfAttention.q.base_layer.weight"]

tensor([[-0.0060,  0.0209, -0.0243,  ...,  0.0023,  0.0087,  0.0052],
        [ 0.0060,  0.0203, -0.0149,  ...,  0.0077,  0.0052,  0.0031],
        [ 0.0087, -0.0002,  0.0015,  ...,  0.0026,  0.0089, -0.0002],
        ...,
        [-0.0021,  0.0024, -0.0078,  ..., -0.0067,  0.0193,  0.0054],
        [-0.0060,  0.0306, -0.0054,  ...,  0.0059, -0.0044, -0.0133],
        [ 0.0127, -0.0177, -0.0060,  ...,  0.0016, -0.0042, -0.0140]])

In [22]:
model.encoder

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=4096, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                 

In [24]:
load_model = FrustrationCNN.load_from_checkpoint(checkpoint_path=f"./{exp_name}/train_best.ckpt",
                                                input_dim=1024, 
                                                hidden_dims=[64, 10], 
                                                output_dim=4, # 3 classes + 1 for regression 
                                                dropout=0.15,
                                                max_seq_length=max_seq_length,
                                                precision="full", 
                                                pLM_model="../data/protT5",
                                                prefix_prostT5="<AA2fold>",
                                                use_loraFT="QK",
                                                no_label_token=-100,)

Using LoRA fine-tuning for Q and K layers only
trainable params: 1,966,080 || all params: 1,210,107,904 || trainable%: 0.1625


In [29]:
load_model.encoder

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=4096, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                 

In [26]:
full_seq = ["S", "E", "Q", "V", "E", "C", "E"]

In [42]:
ids = load_model.tokenizer.batch_encode_plus(full_seq, 
                                        add_special_tokens=True, 
                                        max_length=max_seq_length,
                                        padding="max_length",
                                        truncation="longest_first",
                                        return_tensors='pt'
                                        ).to(load_model.device)
        

embedding_rpr = load_model.encoder(
    input_ids=ids.input_ids, 
    attention_mask=ids.attention_mask
).last_hidden_state.float()

In [43]:
embedding_rpr

tensor([[[ 0.5221, -0.2095, -0.0000,  ..., -0.0000, -0.0414, -0.0365],
         [-0.0354, -0.0668,  0.0000,  ..., -0.0215, -0.0181,  0.0729],
         [ 0.0600, -0.0838, -0.3391,  ...,  0.2179,  0.0000, -0.0508],
         ...,
         [ 0.1201, -0.3043, -0.5045,  ..., -0.0762, -0.1964,  0.0360],
         [-0.0000, -0.0000, -0.2776,  ...,  0.0000, -0.3510, -0.0175],
         [-0.2317, -0.3814, -0.3399,  ..., -0.0044, -0.5797, -0.0097]],

        [[ 0.3358, -0.2482, -0.0000,  ...,  0.0855, -0.2614, -0.1635],
         [-0.0385, -0.0666,  0.0061,  ..., -0.0108, -0.0070,  0.0716],
         [ 0.0909, -0.1780,  0.0624,  ...,  0.2201, -0.0156, -0.0834],
         ...,
         [ 0.1357, -0.4268, -0.2088,  ...,  0.1007, -0.1573,  0.1208],
         [ 0.0724, -0.3627, -0.4703,  ...,  0.0591, -0.1821,  0.0336],
         [ 0.0540, -0.3383, -0.4766,  ...,  0.0757, -0.2415, -0.0931]],

        [[ 0.2456, -0.3815, -0.4543,  ..., -0.0539, -0.1558, -0.1461],
         [-0.0414, -0.1043,  0.0910,  ..., -0

In [30]:
new_model = FrustrationCNN(input_dim=1024, 
                            hidden_dims=[64, 10], 
                            output_dim=4, # 3 classes + 1 for regression 
                            dropout=0.15,
                            max_seq_length=max_seq_length,
                            precision="full", 
                            pLM_model="../data/protT5",
                            prefix_prostT5="<AA2fold>",
                            use_loraFT="QK",
                            no_label_token=-100,)

Using LoRA fine-tuning for Q and K layers only
trainable params: 1,966,080 || all params: 1,210,107,904 || trainable%: 0.1625


In [33]:
ids = new_model.tokenizer.batch_encode_plus(full_seq, 
                                        add_special_tokens=True, 
                                        max_length=max_seq_length,
                                        padding="max_length",
                                        truncation="longest_first",
                                        return_tensors='pt'
                                        ).to(new_model.device)
        

new_embedding_rpr = new_model.encoder(
    input_ids=ids.input_ids, 
    attention_mask=ids.attention_mask
).last_hidden_state.float()

In [34]:
new_embedding_rpr

tensor([[[ 0.0756, -0.3897, -0.4722,  ..., -0.2083, -0.0000, -0.0803],
         [-0.0309, -0.0391,  0.0785,  ..., -0.0131,  0.0172,  0.0517],
         [ 0.1087, -0.2670, -0.1929,  ...,  0.1103, -0.0261, -0.1431],
         ...,
         [-0.0801, -0.2283, -0.4148,  ..., -0.1586, -0.2432,  0.0471],
         [ 0.0455, -0.2686, -0.4092,  ..., -0.0045, -0.3129, -0.0496],
         [ 0.0164, -0.1842, -0.4229,  ...,  0.0619, -0.3047,  0.0650]],

        [[ 0.1457, -0.2528, -0.2619,  ..., -0.0562, -0.3641, -0.1093],
         [-0.0845, -0.1071,  0.0384,  ...,  0.0234,  0.0639,  0.0432],
         [ 0.1488, -0.1847, -0.1277,  ...,  0.2352,  0.1047, -0.2298],
         ...,
         [ 0.0774, -0.3723, -0.2948,  ..., -0.0000, -0.1990, -0.0000],
         [-0.0000, -0.2847,  0.0000,  ..., -0.0656, -0.1496, -0.1729],
         [-0.1205, -0.2443, -0.2521,  ..., -0.0801, -0.1778, -0.0302]],

        [[ 0.0000, -0.1095, -0.2575,  ..., -0.0693, -0.3859, -0.2565],
         [-0.0103, -0.0423,  0.0737,  ..., -0

In [35]:
base_model = FrustrationCNN(input_dim=1024, 
                            hidden_dims=[64, 10], 
                            output_dim=4, # 3 classes + 1 for regression 
                            dropout=0.15,
                            max_seq_length=max_seq_length,
                            precision="full", 
                            pLM_model="../data/protT5",
                            prefix_prostT5="<AA2fold>",
                            use_loraFT=None,
                            no_label_token=-100,)

In [44]:
ids = base_model.tokenizer.batch_encode_plus(full_seq, 
                                        add_special_tokens=True, 
                                        max_length=max_seq_length,
                                        padding="max_length",
                                        truncation="longest_first",
                                        return_tensors='pt'
                                        ).to(base_model.device)
        

base_embedding_rpr = base_model.encoder(
    input_ids=ids.input_ids, 
    attention_mask=ids.attention_mask
).last_hidden_state.float()

In [45]:
base_embedding_rpr

tensor([[[ 0.2118, -0.2282, -0.3910,  ..., -0.1777, -0.2262, -0.1678],
         [-0.0765, -0.0854,  0.0859,  ...,  0.0147,  0.0367,  0.0804],
         [ 0.0152, -0.2382, -0.1728,  ...,  0.0530, -0.0467, -0.0430],
         ...,
         [-0.0863, -0.2326, -0.3657,  ..., -0.0764, -0.2000, -0.0448],
         [-0.0863, -0.2326, -0.3657,  ..., -0.0764, -0.2000, -0.0448],
         [-0.0863, -0.2326, -0.3657,  ..., -0.0764, -0.2000, -0.0448]],

        [[ 0.2630, -0.2531, -0.0511,  ..., -0.0465, -0.4211, -0.1257],
         [-0.0713, -0.0862,  0.0057,  ...,  0.0311,  0.0302,  0.0646],
         [-0.0066, -0.2809, -0.0826,  ...,  0.1239, -0.0013, -0.1679],
         ...,
         [-0.0783, -0.2493, -0.2310,  ..., -0.0557, -0.1860, -0.0014],
         [-0.0783, -0.2493, -0.2310,  ..., -0.0557, -0.1860, -0.0014],
         [-0.0783, -0.2493, -0.2310,  ..., -0.0557, -0.1860, -0.0014]],

        [[ 0.1627, -0.2232, -0.2073,  ..., -0.0837, -0.3656, -0.1257],
         [-0.0553, -0.0819,  0.0635,  ...,  0

### Config Debug

In [8]:
config = {
    "experiment_name": "it3-5_loraAll_small",
    "parquet_path": "../data/frustration/v4_frustration.parquet.gzip",
    "cath_sampling_n": 2,  # None for no sampling
    "batch_size": 64, # 24 for QK, 16 for all modules for FT; 64 maybe more for no FT
    "num_workers": 10,
    "input_dim": 1024,
    "hidden_dims": [64, 10],
    "output_dim": 4, # 3 classes + 1 for regression 
    "dropout": 0.15,
    "max_seq_length": 512,
    "precision": "full",
    "finetune": False,
    "pLM_model": "../data/protT5",
    "prefix_prostT5": "<AA2fold>",
    "no_label_token": -100,
    "lora_r": 4,
    "lora_alpha": 1,
    "lora_modules": ["q", "k"], #["q", "k", "v", "o", "wi", "wo", "w1", "w2", "w3", "fc1", "fc2", "fc3"],
    "ce_weighting": None,
    "notes": "",   
}

if config["precision"] == "full":
    torch.set_float32_matmul_precision("high")
    trainer_precision = "32"
elif config["precision"] == "half":
    torch.set_float32_matmul_precision("medium")
    trainer_precision = "16-mixed"

if config["finetune"]:
    find_unused = False
else:
    find_unused = True

In [9]:
data_module = FrustrationDataModule(df=None,
                                    parquet_path=config["parquet_path"], 
                                    batch_size=config["batch_size"], 
                                    max_seq_length=config["max_seq_length"], 
                                    num_workers=config["num_workers"], 
                                    persistent_workers=True,
                                    sample_size=None,
                                    cath_sampling_n=config["cath_sampling_n"])

model = FrustrationCNN(config=config)

early_stop = EarlyStopping(monitor="val_loss",
                           patience=5,
                           mode='min',
                           verbose=True)
checkpoint = ModelCheckpoint(monitor="val_loss",
                             dirpath=f"./{config['experiment_name']}",
                             filename=f"train_best",
                             save_top_k=1,
                             mode='min',
                             save_weights_only=True)

logger = CSVLogger(f"./{config['experiment_name']}", name="train_logs")

In [10]:
trainer = Trainer(accelerator='auto', # gpu
                  devices=-1, # 4 for one node on haicore
                  #strategy=DDPStrategy(find_unused_parameters=find_unused),
                  max_epochs=2,
                  logger=logger,
                  log_every_n_steps=10,
                  val_check_interval=0.2, #3500 batches in total 
                  callbacks=[early_stop, checkpoint],
                  precision=trainer_precision,
                  gradient_clip_val=1,
                  enable_progress_bar=False,
                  deterministic=False, # for reproducibility disable on cluster
                  accumulate_grad_batches=4, # if batch size gets too small
                  )
trainer.fit(model, datamodule=data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loaded 1124 samples from ../data/frustration/v4_frustration.parquet.gzip
Created train/val/test masks
Initialized res_idx_mask and frst_vals tensors
Populated res_idx_mask and frst_vals tensors
Created train dataset
Created val dataset
Created test dataset
Train/Val/Test split: 964/60/100 samples


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/janleusch/Documents/phd/pLMtrainer/pLMtrainer/notebooks/it3-5_loraAll_small exists and is not empty.

  | Name        | Type             | Params | Mode 
---------------------------------------------------------
0 | encoder     | T5EncoderModel   | 1.2 B  | eval 
1 | CNN         | Sequential       | 463 K  | train
2 | cls_head    | Sequential       | 33     | train
3 | reg_head    | Sequential       | 11     | train
4 | mse_loss_fn | MSELoss          | 0      | train
5 | ce_loss_fn  | CrossEntropyLoss | 0      | train
---------------------------------------------------------
1.2 B     Trainable params
0         Non-trainable params
1.2 B     Total params
4,834.421 Total estimated model params size (MB)
15        Modules in train mode
439       Modules in eval mode


Starting validation 2025-10-28 08:42:03
Ending validation 2025-10-28 08:42:54




Starting validation 2025-10-28 08:47:21


Metric val_loss improved. New best score: 2.130


Ending validation 2025-10-28 08:48:45
Starting validation 2025-10-28 08:53:26


Metric val_loss improved by 0.128 >= min_delta = 0.0. New best score: 2.002


Ending validation 2025-10-28 08:55:18
Starting validation 2025-10-28 09:00:26


Metric val_loss improved by 0.140 >= min_delta = 0.0. New best score: 1.862


Ending validation 2025-10-28 09:02:36
Starting validation 2025-10-28 09:08:21


Metric val_loss improved by 0.131 >= min_delta = 0.0. New best score: 1.730


Ending validation 2025-10-28 09:10:18
Starting validation 2025-10-28 09:13:58
Ending validation 2025-10-28 09:15:14
Starting validation 2025-10-28 09:46:44


Metric val_loss improved by 0.099 >= min_delta = 0.0. New best score: 1.631


Ending validation 2025-10-28 09:48:10
Starting validation 2025-10-28 09:51:28


Metric val_loss improved by 0.080 >= min_delta = 0.0. New best score: 1.551


Ending validation 2025-10-28 09:52:53
Starting validation 2025-10-28 09:56:38


Metric val_loss improved by 0.066 >= min_delta = 0.0. New best score: 1.484


Ending validation 2025-10-28 09:58:09
Starting validation 2025-10-28 10:02:32


Metric val_loss improved by 0.045 >= min_delta = 0.0. New best score: 1.439


Ending validation 2025-10-28 10:04:15
Starting validation 2025-10-28 10:08:44
Ending validation 2025-10-28 10:10:21


`Trainer.fit` stopped: `max_epochs=2` reached.


In [11]:
model.save_preds_dict()

AttributeError: 'FrustrationCNN' object has no attribute 'test_dict'

In [13]:
import yaml

In [14]:
with open(f"./it3-5_loraAll_small/config.yaml", "w") as f:
    yaml.dump(config, f, default_flow_style=False)

In [15]:
#load the config
with open(f"./it3-5_loraAll_small/config.yaml", 'r') as f:
    config_loaded = yaml.safe_load(f)
config_loaded

{'batch_size': 64,
 'cath_sampling_n': 2,
 'ce_weighting': None,
 'dropout': 0.15,
 'experiment_name': 'it3-5_loraAll_small',
 'finetune': False,
 'hidden_dims': [64, 10],
 'input_dim': 1024,
 'lora_alpha': 1,
 'lora_modules': ['q', 'k'],
 'lora_r': 4,
 'max_seq_length': 512,
 'no_label_token': -100,
 'notes': '',
 'num_workers': 10,
 'output_dim': 4,
 'pLM_model': '../data/protT5',
 'parquet_path': '../data/frustration/v4_frustration.parquet.gzip',
 'precision': 'full',
 'prefix_prostT5': '<AA2fold>'}

In [20]:
1/0.1, 1/0.5, 1/0.4

(10.0, 2.0, 2.5)

In [22]:
10/14.5, 2/14.5, 2.5/14.5

(0.6896551724137931, 0.13793103448275862, 0.1724137931034483)

In [23]:
np.exp(0.8) / (np.exp(0.8) + np.exp(0.1) + np.exp(0.1))

np.float64(0.5017131981555416)