---
## Setup and Variables

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import re
import gc
import os
import math
import copy
import types
import yaml
import sys

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)
from torch.utils.data import DataLoader

import evaluate

import transformers
from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    EvalPrediction,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

import peft

from datasets import Dataset

import src.config
import src.data
import src.model_new

from src.model_working import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    )
from src.utils import get_project_root_path
import random

In [3]:
base_model_name = src.config.base_model_name
print("Base Model:\t", base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = get_project_root_path()
print("Path:\t\t", ROOT)
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device:\t {device}")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 False
Path:		 /home/ec2-user/developer/prottrans-t5-signalpeptide-prediction
Using device:	 cuda:0


---
## Create Tokenizer and Load Model

In [4]:
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [5]:
t5_base_model = T5EncoderModel.from_pretrained(
        base_model_name,
        device_map='auto',
        load_in_8bit=False
    )

---
## Apply LoRA

In [6]:
lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],
    bias="none",
    modules_to_save=['custom_classifier'],
)

t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

trainable params: 3,932,160 || all params: 1,212,073,984 || trainable%: 0.32441584027926795


In [7]:
t5_lora_model = inject_linear_layer(
    t5_lora_model=t5_lora_model,
    num_labels=len(src.config.label_decoding),
    dropout_rate=src.config.dropout_rate
    )

In [8]:
# [x for x in t5_lora_model.custom_classifier.named_parameters()]

In [9]:
t5_lora_model.encoder.block[4].layer[0].SelfAttention.v.lora_A.default.weight

Parameter containing:
tensor([[-0.0243,  0.0118, -0.0202,  ...,  0.0279, -0.0208, -0.0121],
        [-0.0100,  0.0219,  0.0305,  ...,  0.0182,  0.0274, -0.0185],
        [-0.0297, -0.0181, -0.0091,  ..., -0.0108, -0.0215,  0.0148],
        ...,
        [-0.0066,  0.0023, -0.0226,  ..., -0.0184,  0.0287,  0.0001],
        [ 0.0218,  0.0278,  0.0081,  ..., -0.0058,  0.0176, -0.0214],
        [ 0.0176,  0.0048, -0.0220,  ..., -0.0177,  0.0093, -0.0017]],
       device='cuda:0', requires_grad=True)

---
## Load Data, Split into Dataset, and Tokenize Sequences

In [10]:
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
annotations_name = 'Label' # Choose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp = src.model_new.create_datasets(
    splits=src.config.splits,
    tokenizer=t5_tokenizer,
    data=df_data,
    annotations_name=annotations_name,
    dataset_size=src.config.dataset_size,
    encoder=src.config.select_encodings[annotations_name],
    )

del df_data

In [11]:
dataset_signalp

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 12462
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4149
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4147
    })
})

---
## Training Loop
https://huggingface.co/docs/peft/task_guides/token-classification-lora

In [12]:
# seqeval_metric = evaluate.load("seqeval")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
# roc_auc_score_metric = evaluate.load("roc_auc", "multiclass")
# roc_auc_score = evaluate.load("roc_auc")
matthews_correlation_metric = evaluate.load("matthews_correlation")

def batch_eval_elementwise(predictions: np.ndarray, references: np.ndarray):
    results = {}
    # predictions = np.nan_to_num(predictions).argmax(axis=-1)
    predictions = predictions.argmax(axis=-1)
    
    results.update({'accuracy_metric': np.average([accuracy_metric.compute(predictions=x, references=y)['accuracy'] for x, y in zip(predictions, references)])})
    results.update({'precision_metric': np.average([precision_metric.compute(predictions=x, references=y, average='micro')['precision'] for x, y in zip(predictions, references)])})
    results.update({'recall_metric': np.average([recall_metric.compute(predictions=x, references=y, average='micro')['recall'] for x, y in zip(predictions, references)])})
    results.update({'f1_metric': np.average([f1_metric.compute(predictions=x, references=y, average='micro')['f1'] for x, y in zip(predictions, references)])})
    # results.update({'roc_auc': np.average([roc_auc_score_metric.compute(prediction_scores=x, references=y, average='micro')['roc_auc'] for x, y in zip(predictions, references)])})
    results.update({'matthews_correlation': np.average([matthews_correlation_metric.compute(predictions=x, references=y, average='micro')['matthews_correlation'] for x, y in zip(predictions, references)])})
    return results
# display(batch_eval_elementwise(predictions.numpy(), references.numpy()))

def compute_metrics(p):
    # print('=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= preds compute_metrics start =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
    predictions, references = p
    results = batch_eval_elementwise(predictions=predictions, references=references)
    # print(results)
    # print('=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= preds compute_metrics stop =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
    return results
# metrics = compute_metrics((predictions, references))

In [13]:
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

training_args = TrainingArguments(
    output_dir='./checkpoints',
    learning_rate=src.config.lr,
    per_device_train_batch_size=src.config.batch_size,
    per_device_eval_batch_size=src.config.batch_size,
    num_train_epochs=src.config.num_epochs,
    logging_steps=src.config.logging_steps,
    # save_strategy="steps",
    # save_steps=config.save_steps,
    # evaluation_strategy="steps",
    # eval_steps=src.config.eval_steps,
    # gradient_accumulation_steps=accum,
    # load_best_model_at_end=True,
    # save_total_limit=5,
    seed=42,
    # fp16=True,
    # deepspeed=deepspeed_config,
    remove_unused_columns=False,
    label_names=['labels'],
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [14]:
# print(next(t5_lora_model.parameters()).is_cuda)
# print(t5_lora_model.device)
# print(config.label_decoding)

In [15]:
gc.collect()
torch.cuda.empty_cache()
# torch.mps.empty_cache()

In [16]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy Metric,Precision Metric,Recall Metric,F1 Metric,Matthews Correlation
150,0.3123,0.312319,0.904834,0.904834,0.904834,0.904834,0.203015
300,0.0397,0.171633,0.927683,0.927683,0.927683,0.927683,0.236499
450,0.1097,0.141002,0.942606,0.942606,0.942606,0.942606,0.259352
600,0.0429,0.126532,0.950468,0.950468,0.950468,0.950468,0.266294
750,0.1267,0.121012,0.952217,0.952217,0.952217,0.952217,0.268881


TrainOutput(global_step=779, training_loss=0.3309959212332409, metrics={'train_runtime': 2479.3407, 'train_samples_per_second': 5.026, 'train_steps_per_second': 0.314, 'total_flos': 6433971638317056.0, 'train_loss': 0.3309959212332409, 'epoch': 1.0})

---

In [None]:
t5_lora_model.encoder.block[4].layer[0].SelfAttention.v.lora_A.default.weight

In [None]:
[x for x in t5_lora_model.custom_classifier.named_parameters()]

In [None]:
# t5_lora_model.encoder.block[4].layer[0].SelfAttention.v.lora_A.default.weight

In [None]:
# metrics=trainer.evaluate()
# print(metrics)

In [17]:
training_log = pd.DataFrame(trainer.state.log_history)
display(training_log)

Unnamed: 0,loss,learning_rate,epoch,step,eval_loss,eval_accuracy_metric,eval_precision_metric,eval_recall_metric,eval_f1_metric,eval_matthews_correlation,eval_runtime,eval_samples_per_second,eval_steps_per_second,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
0,1.8265,9.987163e-05,0.00,1,,,,,,,,,,,,,,
1,1.8314,9.974326e-05,0.00,2,,,,,,,,,,,,,,
2,1.8065,9.961489e-05,0.00,3,,,,,,,,,,,,,,
3,1.8209,9.948652e-05,0.01,4,,,,,,,,,,,,,,
4,1.7812,9.935815e-05,0.01,5,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
780,0.3085,3.851091e-07,1.00,776,,,,,,,,,,,,,,
781,0.0733,2.567394e-07,1.00,777,,,,,,,,,,,,,,
782,0.0172,1.283697e-07,1.00,778,,,,,,,,,,,,,,
783,0.0085,0.000000e+00,1.00,779,,,,,,,,,,,,,,


---

In [19]:
adapter_location = '/models/linear_model_v5'
t5_lora_model.save_pretrained(ROOT + adapter_location)
training_log.to_parquet(ROOT + adapter_location + '/training_log.parquet')

---

In [None]:
for name, param in t5_lora_model.base_model.named_parameters():
    if "lora" not in name:
        continue
    if param.isnan().any():
        print(f"New parameter {name:<13} | {param.numel():>5} parameters | not updated")
    else:
        print(f"New parameter {name:<13} | {param.numel():>5} parameters | updated")

---

In [20]:
def predict_model(sequence: str, tokenizer: T5Tokenizer, model):
    # print('sequence', sequence)
    tokenized_string = tokenizer.encode(sequence, padding=True, truncation=True, return_tensors="pt", max_length=1024)
    # print('tokenized_string', tokenized_string)
    with torch.no_grad():
        output = model(tokenized_string.to(device))
    # print('output', output)
    return output

def translate_logits(logits):
    return [src.config.label_decoding[x] for x in logits.argmax(-1).tolist()[0]]

In [21]:
_ds_index = 3290
_ds_type = 'test'

_inids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'])
_labels_test = dataset_signalp[_ds_type][_ds_index]['labels']
_labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test]
print(_inids_test)
print(_labels_test)
print(_labels_test_decoded)

M G H V N L P A S K R G N P R Q W R L L D I V T A A F F G I V L L F F I L L F T P L G D S M A A S G R Q T L L L S T A S D P R Q R Q R L V T</s>
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
['I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [22]:
preds = predict_model(_inids_test, t5_tokenizer, t5_lora_model)



In [23]:
_res = translate_logits(preds.logits.cpu().numpy())
print(_res)

['I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'M', 'I', 'M', 'O', 'I', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I']
