# ProtTans Finetuning with LoRA for Signal Peptide Prediction

## Links
### Papers/ Knowledge
- https://www.sciencedirect.com/science/article/pii/S2001037021000945
- https://huggingface.co/blog/peft
- https://ieeexplore.ieee.org/ielx7/34/9893033/9477085/supp1-3095381.pdf?arnumber=9477085
### Architecture
- https://www.philschmid.de/fine-tune-flan-t5-peft
- https://huggingface.co/spaces/evaluate-metric/seqeval
- https://huggingface.co/docs/transformers/v4.33.3/en/model_doc/esm#transformers.EsmForTokenClassification
- https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/builder_classes#datasets.SplitGenerator
- https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/main_classes#datasets.Dataset.add_column
- https://huggingface.co/docs/transformers/main_classes/data_collator
- https://huggingface.co/docs/transformers/main/en/main_classes/trainer#checkpoints
### Code
- https://github.com/ziegler-ingo/cleavage_extended/blob/master/models/final/c_bilstm_t5_coteaching.ipynb
- https://www.kaggle.com/code/henriupton/proteinet-pytorch-ems2-t5-protbert-embeddings/notebook#7.-Train-the-Model
- https://www.kaggle.com/code/prithvijaunjale/t5-multi-label-classification
### Optmization
- https://huggingface.co/blog/accelerate-large-models
- https://huggingface.co/docs/transformers/hpo_train

## ToDo
- Implement BitsAndBites (QLoRA)
- Implement DeepSpeed
- Fix weird extra char on inference

## Notebook Setup
___

In [1]:
%load_ext autoreload
%autoreload 2

## Packages
___

In [3]:
import re
import os
import math
import copy
import types
import yaml

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)
from torch.utils.data import DataLoader

import evaluate

from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

from datasets import Dataset

import src.config as config

from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_fast
    )
from src.utils import get_project_root_path

---
## Setup and Variables

In [4]:
base_model_name = config.base_model_name
print("Base Model:\t", base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = get_project_root_path()
print("Path:\t\t", ROOT)
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device:\t {device}")

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


---
## Create Tokenizer and Load Model

In [5]:
# model_architecture = T5ForConditionalGeneration
model_architecture = T5EncoderModel
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(base_model_name, model_architecture)

---
## Load Data, Split into Dataset, and Tokenize Sequences

In [6]:
df_data = pd.read_parquet(ROOT + '/data/processed/5.0_train.parquet.gzip')

In [7]:
df_data.head()

Unnamed: 0,Sequence,Label,Split
0,M A P T L F Q K L F S K R T G L G A P G R D A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0
1,M D F T S L E T T T F E E V V I A L G S N V G ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1
2,M D D I S G R Q T L P R I N R L L E H V G N P ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1
3,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4
4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4


In [17]:
ds_train = df_data[df_data.Split.isin([0, 1, 2])].head(config.dataset_size*3)
ds_train = df_to_dataset(
    t5_tokenizer,
    ds_train.Sequence.to_list(),
    ds_train.Label.to_list(),
)

ds_validate = df_data[df_data.Split.isin([3])].head(config.dataset_size)
ds_validate = df_to_dataset(
    t5_tokenizer,
    ds_validate.Sequence.to_list(),
    ds_validate.Label.to_list(),
)

ds_test = df_data[df_data.Split.isin([4])].head(config.dataset_size)
ds_test = df_to_dataset(
    t5_tokenizer,
    ds_test.Sequence.to_list(),
    ds_test.Label.to_list()
)

In [18]:
print(ds_train)
print(ds_validate)
print(ds_test)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 6
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2
})


In [19]:
print(ds_test[0]['input_ids'])
print(len(ds_test[0]['input_ids']))
print(ds_test[0]['attention_mask'])
print(len(ds_test[0]['attention_mask']))
print(ds_test[0]['labels'])
print(len(ds_test[0]['labels']))

[19, 4, 5, 11, 6, 14, 19, 9, 5, 20, 9, 11, 7, 10, 21, 17, 7, 18, 18, 3, 10, 11, 16, 9, 3, 18, 7, 7, 6, 13, 6, 7, 17, 19, 17, 7, 5, 4, 5, 7, 19, 17, 7, 19, 17, 11, 18, 19, 11, 19, 17, 11, 19, 11, 11, 7, 5, 17, 19, 11, 13, 3, 7, 15, 17, 19, 7, 18, 3, 17, 1]
71
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
71
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
70


In [20]:
t5_tokenizer.decode(ds_test[0]['input_ids'])

'M L G T V K M E G H E T S D W N S Y Y A D T Q E A Y S S V P V S N M N S G L G S M N S M N T Y M T M N T M T T S G N M T P A S F N M S Y A N</s>'

In [21]:
t5_tokenizer.decode(range(0, 28))

'<pad></s><unk>A L G V S R E D T I P K F Q N Y M H W C X B O U Z'

---
## Apply LoRA

In [22]:
lora_config = LoraConfig(
        # task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=['q', 'k', 'v', 'o'],
        # target_modules=['o'],
        bias="none",
    )

In [23]:
t5_lora_model = get_peft_model(t5_base_model, lora_config)
del t5_base_model
# t5_lora_model = prepare_model_for_kbit_training(t5_lora_model) # add quantization

In [24]:
# t5_lora_model = t5_base_model

In [25]:
t5_lora_model.print_trainable_parameters()

trainable params: 3,932,160 || all params: 1,212,073,984 || trainable%: 0.32441584027926795


---
## DeepSpeed

In [26]:
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "9994"  # modify if RuntimeError: Address already in use
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

---
## Training Loop
https://huggingface.co/docs/peft/task_guides/token-classification-lora

In [27]:
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

In [28]:
with open(ROOT+'/deepspeed_config.yaml', 'r') as file:
    deepspeed_config = yaml.safe_load(file)

In [86]:
training_args = TrainingArguments(
    output_dir='./',
    learning_rate=config.lr,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    num_train_epochs=config.num_epochs,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=config.save_steps,
    logging_steps=config.logging_steps,
    load_best_model_at_end=True,
    save_total_limit=5,
    seed=42,
    # deepspeed=deepspeed_config
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_validate,
    data_collator=data_collator,
    compute_metrics=config.metric
)

In [87]:
# t5_lora_model = t5_lora_model.to('cpu')

In [88]:
print(next(t5_lora_model.parameters()).is_cuda)
print(t5_lora_model.device)

False
mps:0


In [89]:
config.label_decoding

{0: 'I', 1: 'L', 2: 'M', 3: 'O', 4: 'S', 5: 'T'}

In [90]:
t5_lora_model = inject_linear_layer(
    t5_lora_model=t5_lora_model,
    num_labels=config.label_decoding.__len__(),
    dropout_rate=config.dropout_rate
    )

In [91]:
trainer.train()

  0%|          | 0/1 [00:00<?, ?it/s]

forward...
{'loss': 1.6068, 'learning_rate': 0.0, 'epoch': 1.0}


ValueError: Trainer: evaluation requires an eval_dataset.

---
## Results

In [81]:
pd.DataFrame(trainer.state.log_history)

Unnamed: 0,loss,learning_rate,epoch,step,eval_runtime,eval_samples_per_second,eval_steps_per_second
0,1.6952,0.0,1.0,1,,,
1,,,1.0,1,75.1006,0.027,0.013


---
## Save Trainer

Q02742|EUKARYA|NO_SP|4\
MLRTLLRRRLFSYPTKYYFMVLVLSLITFSVLRIHQKPEFVSVRHLELAGENPSSDINCTKVLQGDVNEI\
IIIIIIIIIMMMMMMMMMMMMMMMMMMMMMMMOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO

Q9NQR9|EUKARYA|NO_SP|4\
MDFLHRNGVLIIQHLQKDYRAYYTFLNFMSNVGDPRNIFFIYFPLCFQFNQTVGTKMIWVAVIGDWLNLI\
OOOOOOOOOOOOOOOOOOOOOOOOMMMMMMMMMMMMMMMMMMMMMIIIIIIIIIIIMMMMMMMMMMMMMM

In [82]:
print(*ds_test[1])

input_ids attention_mask labels


In [83]:
t5_lora_model.to(device)

test_set = ds_test.select(range(0,2)).with_format("torch", device=device)

# For token classification we need a data collator here to pad correctly
data_collator = DataCollatorForTokenClassification(t5_tokenizer) 

# Create a dataloader for the test dataset
test_dataloader = DataLoader(test_set, batch_size=16, shuffle = False, collate_fn = data_collator)

# Put the model in evaluation mode
t5_lora_model.eval()

# Make predictions on the test dataset
predictions = []
# We need to collect the batch["labels"] as well, this allows us to filter out all positions with a -100 afterwards
padded_labels = []

counter = 0

In [84]:
with torch.no_grad():
    for batch in test_dataloader:
        # print(counter)
        counter += 1
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        padded_labels += batch['labels'].tolist()
        # Add batch results(logits) to predictions, we take the argmax here to get the predicted class
        prediction = t5_lora_model(input_ids=input_ids).logits.argmax(dim=-1).tolist()
        print(prediction)
        predictions += prediction#.argmax(dim=-1).tolist()

forward...
[[4, 5, 3, 0, 2, 5, 1, 0, 1, 0, 2, 1, 4, 3, 3, 0, 2, 3, 3, 4, 0, 0, 0, 0, 2, 0, 0, 0, 5, 3, 2, 0, 0, 3, 1, 1, 1, 1, 1, 2, 3, 3, 2, 3, 3, 2, 3, 5, 3, 3, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0, 2, 4, 2, 0, 0, 0, 2, 0, 0, 0, 2], [0, 5, 1, 0, 2, 5, 1, 0, 1, 0, 0, 1, 4, 3, 5, 4, 2, 3, 3, 0, 0, 5, 0, 0, 0, 4, 4, 2, 3, 0, 3, 0, 1, 1, 1, 3, 0, 0, 3, 3, 3, 2, 0, 4, 4, 4, 4, 4, 4, 4, 0, 1, 1, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2]]


In [85]:
index_item = 1

actual = [config.label_decoding[x] for x in test_set['labels'][index_item].tolist()]
print(actual.__len__())
print(*actual)

pred = [config.label_decoding[x] for x in predictions[index_item]]
print(pred.__len__())
print(*pred)

70
I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I
71
I T L I M T L I L I I L S O T S M O O I I T I I I S S M O I O I L L L O I I O O O M I S S S S S S S I L L I I I I O I S I I I I I I M I I I M


---
## Save Model

In [None]:
t5_lora_model.save_pretrained(ROOT + '/models/linear_model_v2')

---
---
---