In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import copy

import torch
import pandas as pd
import numpy as np

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    create_datasets
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
    PeftModel
)

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
ROOT = '../'
torch.manual_seed(42)

<torch._C.Generator at 0x7f7b20fd9450>

---

In [4]:
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [5]:
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    )

Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier.bias', 'custom_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
annotations_name = 'Label' # Chose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp = create_datasets(
    splits=src.config.splits,
    tokenizer=t5_tokenizer,
    data=df_data,
    annotations_name=annotations_name,
    dataset_size=src.config.dataset_size,
    encoder=src.config.select_encodings[annotations_name],
    )

del df_data

---

In [27]:
lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],
    bias="none",
)
t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

trainable params: 3,932,160 || all params: 1,212,080,134 || trainable%: 0.3244141942186143


In [28]:
# seqeval_metric = evaluate.load("seqeval")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
# roc_auc_score_metric = evaluate.load("roc_auc", "multiclass")
# roc_auc_score = evaluate.load("roc_auc")
matthews_correlation_metric = evaluate.load("matthews_correlation")

In [29]:
# predictions = torch.tensor([[[1,0,0,0,1], [2,3,2,1,2], [2,2,5,1,3]],
#                             [[1,2,0,0,0], [1,float('nan'),1,0,100], [1,4,3,7,10]]])
# predictions_argmaxed = np.nan_to_num(predictions).argmax(axis=-1)
# predictions_argmaxed = predictions.nan_to_num().argmax(dim=-1)
# print(predictions_argmaxed)

# references = torch.tensor([[0,1,3], [1,4,3]])
# print(references)

# torch.Size([3, 71, 1024])
# print(predictions.shape)

In [30]:
# torch.tensor([19,  4,  5, 11,  6, 14, 19,  9,  5, 20,  9, 11,  7, 10, 21, 17,  7, 18,
#         18,  3, 10, 11, 16,  9,  3, 18,  7,  7,  6, 13,  6,  7, 17, 19, 17,  7,
#          5,  4,  5,  7, 19, 17,  7, 19, 17, 11, 18, 19, 11, 19, 17, 11, 19, 11,
#         11,  7,  5, 17, 19, 11, 13,  3,  7, 15, 17, 19,  7, 18,  3, 17,  1],
#        device='mps:0')
# results = roc_auc_score_metric.compute(references=references, prediction_scores=predictions[0], multi_class='ovr')
# print(round(results['roc_auc'], 2))

In [31]:
def batch_eval_flatten(predictions: np.ndarray, references: np.ndarray):
    results = {}
    predictions = np.nan_to_num(predictions).argmax(axis=-1)
    predictions = np.ndarray.flatten(predictions)
    references = np.ndarray.flatten(references)
    
    results.update(accuracy_metric.compute(predictions=predictions, references=references))
    results.update(precision_metric.compute(predictions=predictions, references=references, average='micro'))
    results.update(recall_metric.compute(predictions=predictions, references=references, average='micro'))
    results.update(f1_metric.compute(predictions=predictions, references=references, average='micro'))
    # results.update(roc_auc_score_metric.compute(prediction_scores=predictions, references=references, average='micro'))
    results.update(matthews_correlation_metric.compute(predictions=predictions, references=references, average='micro'))
    return results
# display(batch_eval_flatten(predictions.numpy(), references.numpy()))

def batch_eval_elementwise(predictions: np.ndarray, references: np.ndarray):
    results = {}
    # predictions = np.nan_to_num(predictions).argmax(axis=-1)
    predictions = predictions.argmax(axis=-1)
    
    results.update({'accuracy_metric': np.average([accuracy_metric.compute(predictions=x, references=y)['accuracy'] for x, y in zip(predictions, references)])})
    results.update({'precision_metric': np.average([precision_metric.compute(predictions=x, references=y, average='micro')['precision'] for x, y in zip(predictions, references)])})
    results.update({'recall_metric': np.average([recall_metric.compute(predictions=x, references=y, average='micro')['recall'] for x, y in zip(predictions, references)])})
    results.update({'f1_metric': np.average([f1_metric.compute(predictions=x, references=y, average='micro')['f1'] for x, y in zip(predictions, references)])})
    # results.update({'roc_auc': np.average([roc_auc_score_metric.compute(prediction_scores=x, references=y, average='micro')['roc_auc'] for x, y in zip(predictions, references)])})
    results.update({'matthews_correlation': np.average([matthews_correlation_metric.compute(predictions=x, references=y, average='micro')['matthews_correlation'] for x, y in zip(predictions, references)])})
    return results
# display(batch_eval_elementwise(predictions.numpy(), references.numpy()))

In [32]:
def compute_metrics(p):
    # print('=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= preds compute_metrics start =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
    predictions, references = p
    results = batch_eval_elementwise(predictions=predictions, references=references)
    # print(results)
    # print('=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= preds compute_metrics stop =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=')
    return results
# metrics = compute_metrics((predictions, references))

In [33]:
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

training_args = TrainingArguments(
    output_dir='./checkpoints',
    learning_rate=src.config.lr,
    per_device_train_batch_size=src.config.batch_size,
    per_device_eval_batch_size=src.config.batch_size,
    num_train_epochs=src.config.num_epochs,
    logging_steps=src.config.logging_steps,
    # save_strategy="steps",
    # save_steps=src.config.save_steps,
    # evaluation_strategy="steps",
    # eval_steps=10,
    # gradient_accumulation_steps=accum,
    # load_best_model_at_end=True,
    # save_total_limit=5,
    seed=42,
    # fp16=True,
    # deepspeed=deepspeed_config,
    remove_unused_columns=False,
    label_names=['labels']
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [34]:
gc.collect()
torch.cuda.empty_cache()
# torch.mps.empty_cache()

In [35]:
trainer.train()

Step,Training Loss
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0


TrainOutput(global_step=779, training_loss=0.0, metrics={'train_runtime': 1041.8065, 'train_samples_per_second': 11.962, 'train_steps_per_second': 0.748, 'total_flos': 6434004287510856.0, 'train_loss': 0.0, 'epoch': 1.0})

In [41]:
metrics=trainer.evaluate()
print(metrics)

{'eval_loss': nan, 'eval_accuracy_metric': 0.6884808489403519, 'eval_precision_metric': 0.6884808489403519, 'eval_recall_metric': 0.6884808489403519, 'eval_f1_metric': 0.6884808489403519, 'eval_matthews_correlation': 0.0, 'eval_runtime': 237.3429, 'eval_samples_per_second': 17.481, 'eval_steps_per_second': 1.095, 'epoch': 1.0}


In [37]:
result_log = pd.DataFrame(trainer.state.log_history)
display(result_log)

Unnamed: 0,loss,learning_rate,epoch,step,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
0,0.0,0.000999,0.00,1,,,,,
1,0.0,0.000997,0.00,2,,,,,
2,0.0,0.000996,0.00,3,,,,,
3,0.0,0.000995,0.01,4,,,,,
4,0.0,0.000994,0.01,5,,,,,
...,...,...,...,...,...,...,...,...,...
775,0.0,0.000004,1.00,776,,,,,
776,0.0,0.000003,1.00,777,,,,,
777,0.0,0.000001,1.00,778,,,,,
778,0.0,0.000000,1.00,779,,,,,


In [39]:
dataset_signalp

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 12462
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4149
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4147
    })
})

---

In [40]:
adapter_location = '/models/linear_model_v4'
t5_lora_model.save_pretrained(ROOT + adapter_location)

---

In [42]:
def predict_model(sequence: str, tokenizer: T5Tokenizer, model: T5EncoderModelForTokenClassification):
    # print('sequence', sequence)
    tokenized_string = tokenizer.encode(sequence, padding=True, truncation=True, return_tensors="pt", max_length=1024)
    # print('tokenized_string', tokenized_string)
    with torch.no_grad():
        output = model(tokenized_string.to(device))
    # print('output', output)
    return output

def tranlate_logits(logits):
    return [src.config.label_decoding[x] for x in logits.argmax(-1).tolist()[0]]

In [43]:
_ds_index = 3211
_ds_type = 'test'

_inids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'])
_labels_test = dataset_signalp[_ds_type][_ds_index]['labels']
_labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test]
print(_inids_test)
print(_labels_test)
print(_labels_test_decoded)

M N V F F M F S L L F L A A L G S C A H D R N P L E E C F R E T D Y E E F L E I A K N G L T A T S N P K R V V I V G A G M A G L S A A Y V L</s>
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [44]:
preds = predict_model(_inids_test, t5_tokenizer, t5_lora_model)



In [45]:
_res = tranlate_logits(preds.logits.cpu().numpy())
print(_res)

['I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I']
