In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import copy

import torch
import pandas as pd
import numpy as np

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    create_datasets
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
    PeftModel
)

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
ROOT = '../'

In [4]:
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [5]:
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    )

Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier.weight', 'custom_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
annotations_name = 'Label' # Chose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp = create_datasets(
    splits=src.config.splits,
    tokenizer=t5_tokenizer,
    data=df_data,
    annotations_name=annotations_name,
    dataset_size=src.config.dataset_size,
    encoder=src.config.select_encodings[annotations_name],
    )

del df_data

In [7]:
lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],
    bias="none",
)
t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

trainable params: 3,932,160 || all params: 1,212,080,134 || trainable%: 0.3244141942186143


In [8]:
# seqeval_metric = evaluate.load("seqeval")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
# roc_auc_score_metric = evaluate.load("roc_auc", "multiclass")
# roc_auc_score = evaluate.load("roc_auc")
matthews_correlation_metric = evaluate.load("matthews_correlation")

In [9]:
# predictions = torch.tensor([[[1,0,0,0,1], [2,3,2,1,2], [2,2,5,1,3]],
#                             [[1,2,0,0,0], [1,float('nan'),1,0,100], [1,4,3,7,10]]])
# predictions_argmaxed = np.nan_to_num(predictions).argmax(axis=-1)
# predictions_argmaxed = predictions.nan_to_num().argmax(dim=-1)
# print(predictions_argmaxed)

# references = torch.tensor([[0,1,3], [1,4,3]])
# print(references)

# torch.Size([3, 71, 1024])
# print(predictions.shape)

In [10]:
# torch.tensor([19,  4,  5, 11,  6, 14, 19,  9,  5, 20,  9, 11,  7, 10, 21, 17,  7, 18,
#         18,  3, 10, 11, 16,  9,  3, 18,  7,  7,  6, 13,  6,  7, 17, 19, 17,  7,
#          5,  4,  5,  7, 19, 17,  7, 19, 17, 11, 18, 19, 11, 19, 17, 11, 19, 11,
#         11,  7,  5, 17, 19, 11, 13,  3,  7, 15, 17, 19,  7, 18,  3, 17,  1],
#        device='mps:0')
# results = roc_auc_score_metric.compute(references=references, prediction_scores=predictions[0], multi_class='ovr')
# print(round(results['roc_auc'], 2))

In [11]:
def batch_eval_flatten(predictions: np.ndarray, references: np.ndarray):
    results = {}
    predictions = np.nan_to_num(predictions).argmax(axis=-1)
    predictions = np.ndarray.flatten(predictions)
    references = np.ndarray.flatten(references)
    
    results.update(accuracy_metric.compute(predictions=predictions, references=references))
    results.update(precision_metric.compute(predictions=predictions, references=references, average='micro'))
    results.update(recall_metric.compute(predictions=predictions, references=references, average='micro'))
    results.update(f1_metric.compute(predictions=predictions, references=references, average='micro'))
    # results.update(roc_auc_score_metric.compute(prediction_scores=predictions, references=references, average='micro'))
    results.update(matthews_correlation_metric.compute(predictions=predictions, references=references, average='micro'))
    return results
# display(batch_eval_flatten(predictions.numpy(), references.numpy()))

def batch_eval_elementwise(predictions: np.ndarray, references: np.ndarray):
    results = {}
    # predictions = np.nan_to_num(predictions).argmax(axis=-1)
    predictions = predictions.argmax(axis=-1)
    
    results.update({'accuracy_metric': np.average([accuracy_metric.compute(predictions=x, references=y)['accuracy'] for x, y in zip(predictions, references)])})
    results.update({'precision_metric': np.average([precision_metric.compute(predictions=x, references=y, average='micro')['precision'] for x, y in zip(predictions, references)])})
    results.update({'recall_metric': np.average([recall_metric.compute(predictions=x, references=y, average='micro')['recall'] for x, y in zip(predictions, references)])})
    results.update({'f1_metric': np.average([f1_metric.compute(predictions=x, references=y, average='micro')['f1'] for x, y in zip(predictions, references)])})
    # results.update({'roc_auc': np.average([roc_auc_score_metric.compute(prediction_scores=x, references=y, average='micro')['roc_auc'] for x, y in zip(predictions, references)])})
    results.update({'matthews_correlation': np.average([matthews_correlation_metric.compute(predictions=x, references=y, average='micro')['matthews_correlation'] for x, y in zip(predictions, references)])})
    return results
# display(batch_eval_elementwise(predictions.numpy(), references.numpy()))

In [12]:
def compute_metrics(p):
    # print('=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= preds compute_metrics start =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
    predictions, references = p
    results = batch_eval_elementwise(predictions=predictions, references=references)
    # print(results)
    # print('=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= preds compute_metrics stop =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
    return results
# metrics = compute_metrics((predictions, references))

In [13]:
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

training_args = TrainingArguments(
    output_dir='./checkpoints',
    learning_rate=src.config.lr,
    per_device_train_batch_size=src.config.batch_size,
    per_device_eval_batch_size=src.config.batch_size,
    num_train_epochs=src.config.num_epochs,
    logging_steps=src.config.logging_steps,
    # save_strategy="steps",
    # save_steps=src.config.save_steps,
    evaluation_strategy="steps",
    eval_steps=10,
    # gradient_accumulation_steps=accum,
    # load_best_model_at_end=True,
    # save_total_limit=5,
    seed=42,
    # fp16=True,
    # deepspeed=deepspeed_config,
    remove_unused_columns=False,
    label_names=['labels']
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [14]:
gc.collect()
torch.cuda.empty_cache()
# torch.mps.empty_cache()

In [15]:
trainer.train()

Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# metrics=trainer.evaluate()
# print(metrics)

In [None]:
result_log = pd.DataFrame(trainer.state.log_history)
display(result_log)

---

In [None]:
dataset_signalp['test'][0]['input_ids']

In [None]:
def predict_model(sequence: str, tokenizer: T5Tokenizer, model: T5EncoderModelForTokenClassification):
    tokenized_string = tokenizer.encode(sequence, padding=True, truncation=True, return_tensors="pt", max_length=1024)
    print(tokenized_string)
    with torch.no_grad():
        output = model(tokenized_string.to(device))
    print(output)

In [None]:
test_seq = 'M K N W L L L S V P L L L L L G S S S'

In [None]:
predict_model(dataset_signalp['test'][0]['input_ids'], t5_tokenizer, t5_lora_model)

---

In [None]:
adapter_location = '/models/testing'
t5_lora_model.save_pretrained(ROOT + adapter_location)

---

In [None]:
t5_base_model_reloaded = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    )

In [None]:
t5_base_model_reloaded_original = copy.deepcopy(t5_base_model_reloaded)

In [None]:
t5_lora_model_reloaded = PeftModel.from_pretrained(
    model = t5_base_model_reloaded,
    is_trainable=False,
    model_id=ROOT+adapter_location,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
)

In [None]:
for index, (param1, param2) in enumerate(zip(t5_base_model_reloaded_original.parameters(), t5_base_model_reloaded.parameters())):
    if not torch.equal(param1.data, param2.data):
        print(f"Models have different weights on layer {index}")
        print(param1.data)
        print(param2.data)
        break
else:
    print("Models have identical weights")


In [None]:
for index, (param1, param2) in enumerate(zip(t5_lora_model_reloaded.parameters(), t5_lora_model.parameters())):
    if not torch.equal(param1.data, param2.data):
        print(f"Models have different weights on layer {index}")
        break
else:
    print("Models have identical weights")


In [None]:
torch.set_printoptions(profile="default")

In [None]:
z = [x for x in t5_base_model_reloaded.parameters()]
a = [x for x in t5_lora_model_reloaded.parameters()]
b = [x for x in t5_lora_model.parameters()]

In [None]:
print(len(z)) # base reload
print(len(a)) # lora reload
print(len(b)) # lora

In [None]:
curr_index = 1

In [None]:
a[curr_index].shape

In [None]:
torch.equal(a[curr_index], b[curr_index])

In [None]:
print(z[curr_index])
print(a[curr_index])
print(b[curr_index])

In [None]:
print(sum(sum(z[curr_index])))
print(sum(sum(a[curr_index])))
print(sum(sum(b[curr_index])))

In [None]:
# ds_test[0]

In [None]:
defaul_reloaded = [x for x in t5_base_model_reloaded.parameters()][195]
defaul_reloaded_og = [x for x in t5_base_model_reloaded_original.parameters()][195]

In [None]:
for index, (x, y) in enumerate(zip(defaul_reloaded, defaul_reloaded_og)):
    if not torch.equal(x, y):
        print(index)
        print(x)
        print(y)
        print('-------------------')

In [None]:
print(*defaul_reloaded[2].tolist())
print(*defaul_reloaded_og[2].tolist())

In [None]:
(defaul_reloaded_og[2].tolist()[9])

In [None]:
with torch.no_grad():
    embds_1 = t5_base_model_reloaded.encoder(
        input_ids=torch.tensor([[7, 7, 7, 7, 7]]).to('mps'),
        attention_mask=torch.tensor([[1, 1, 1, 1, 1]]).to('mps')
    )

In [None]:
embds_2 = t5_base_model_reloaded_original.forward(
    input_ids=torch.tensor([[7, 4, 7, 11, 7]]).to('mps'),
    attention_mask=torch.tensor([[1, 1, 1, 1, 1]]).to('mps')
)

In [None]:
embds_1.last_hidden_state