---
## Setup and Variables

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import copy
import random

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

import seaborn as sns

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback
)

from src.model_new import (
    T5EncoderModelForTokenClassification
)

import src.config
import src.data
import src.model_new
import src.utils


import peft
from peft import (
    LoraConfig,
)


import sklearn.metrics

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [3]:
print("Base Model:\t", src.config.base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = src.utils.get_project_root_path()
print("Path:\t\t", ROOT)
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device:\t {device}")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


---
## Create Tokenizer and Load Model

In [4]:
t5_tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    do_lower_case=False,
    use_fast=True,
    legacy=False
)

In [5]:
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    use_crf=True
    )

Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier.bias', 'custom_classifier.weight', 'crf.start_transitions', 'crf.end_transitions', 'crf.transitions']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
t5_base_model.custom_classifier.weight = nn.Linear(
    in_features=t5_base_model.config.hidden_size,
    out_features=t5_base_model.custom_num_labels).weight
t5_base_model.crf.reset_parameters()

---
## Apply LoRA

In [7]:
lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],
    bias="none",
    modules_to_save=['custom_classifier', 'crf'],
)

t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

trainable params: 3,944,556 || all params: 1,212,086,380 || trainable%: 0.3254352218692532


In [8]:
# t5_lora_model.base_model.custom_classifier.modules_to_save.default.weight

In [9]:
# [x[0] for x in t5_base_model.crf.named_parameters()]
# t5_lora_model.crf.transitions

In [10]:
# [x for x in t5_lora_model.named_parameters() if 'crf' in x[0]]

---
## Load Data, Split into Dataset, and Tokenize Sequences

In [11]:
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
# FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
annotations_name = 'Label' # Choose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp = src.model_new.create_datasets(
    splits=src.config.splits,
    tokenizer=t5_tokenizer,
    data=df_data,
    annotations_name=annotations_name,
    # dataset_size=src.config.dataset_size,
    dataset_size=3,
    encoder=src.config.select_encodings[annotations_name],
    )

del df_data

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [26]:
display(dataset_signalp)

ds_index = 0
print(dataset_signalp['valid'][ds_index]['input_ids'])
print(dataset_signalp['valid'][ds_index]['labels'])
print(dataset_signalp['valid'][ds_index]['attention_mask'])

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3
    })
})

[19, 9, 13, 12, 10, 10, 12, 4, 15, 9, 6, 11, 10, 3, 15, 14, 11, 16, 14, 9, 10, 4, 4, 9, 4, 6, 11, 4, 12, 10, 12, 18, 5, 9, 16, 6, 17, 16, 9, 5, 7, 18, 9, 9, 14, 11, 8, 15, 12, 9, 11, 4, 17, 11, 4, 4, 9, 10, 17, 13, 7, 11, 11, 5, 9, 12, 5, 21, 10, 4, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


---

In [13]:
import torch
from torchcrf import CRF
num_tags = 5
model = CRF(num_tags=num_tags, batch_first=True)

In [14]:
batch_size = 2
seq_length = 4
emissions = torch.randn(batch_size, seq_length, num_tags)
tags = torch.tensor([
    [1, 2, 3, 3],
    [2, 2, 2, 3]
    ], dtype=torch.long)  # (seq_length, batch_size)

display(emissions, emissions.shape)
display(tags, tags.shape)
display(model(emissions, tags))
display(torch.Tensor(model.decode(emissions)).shape)
display(model.decode(emissions))

tensor([[[-0.9631,  1.0844, -1.5489,  2.1360,  0.0671],
         [ 0.7196, -0.4391, -1.2591, -2.4885, -0.5998],
         [ 0.3651, -2.0420,  0.7909,  0.2740,  1.1014],
         [ 0.2062,  0.2345, -0.7621, -0.4025,  1.3991]],

        [[ 1.0017, -0.4141, -0.2636, -2.1617,  0.5561],
         [ 0.1655,  0.3286, -0.6225, -0.5243, -0.1821],
         [-0.1025, -0.4402, -1.6304,  0.0114, -0.0502],
         [ 2.1811,  0.5598,  0.8657,  0.4453,  0.1986]]])

torch.Size([2, 4, 5])

tensor([[1, 2, 3, 3],
        [2, 2, 2, 3]])

torch.Size([2, 4])

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


tensor(-17.9651, grad_fn=<SumBackward0>)

torch.Size([2, 4])

[[3, 0, 4, 4], [0, 1, 3, 0]]

---
## Training Loop
https://huggingface.co/docs/peft/task_guides/token-classification-lora

In [15]:
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

training_args = TrainingArguments(
    output_dir='./checkpoints',
    learning_rate=src.config.lr,
    per_device_train_batch_size=src.config.batch_size,
    per_device_eval_batch_size=src.config.batch_size,
    num_train_epochs=src.config.num_epochs,
    logging_steps=src.config.logging_steps,
    # save_strategy="steps",
    # save_steps=src.config.save_steps,
    # evaluation_strategy="steps",
    # eval_steps=src.config.eval_steps,
    # gradient_accumulation_steps=accum,
    # load_best_model_at_end=True,
    # save_total_limit=5,
    seed=42,
    # fp16=True,
    # deepspeed=deepspeed_config,
    remove_unused_columns=False,
    label_names=['labels'],
    # debug="underflow_overflow",
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=src.model_new.compute_metrics,
)

# class EvaluateFirstStepCallback(TrainerCallback):
#     def on_step_begin(self, args, state, control, **kwargs):
#         if state.global_step == 0:
#             control.should_evaluate = True
# trainer.add_callback(EvaluateFirstStepCallback())

In [16]:
gc.collect()
torch.cuda.empty_cache()
# torch.mps.empty_cache()

In [17]:
trainer.train()

  0%|          | 0/1 [00:00<?, ?it/s]

loc("mps_select"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/75428952-3aa4-11ee-8b65-46d450270006/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":294:0)): error: 'anec.gain_offset_control' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x9x1x1xi1>'
  score = torch.where(mask[i].unsqueeze(1), next_score, score)
  seq_ends = mask.long().sum(dim=0) - 1


decoded_tags [[1, 5, 2, 2, 2, 2, 2, 2, 2, 4, 3, 3, 2, 2, 4, 4, 4, 3, 2, 4, 3, 2, 4, 4, 3, 2, 4, 4, 5, 5, 0, 0, 3, 4, 0, 4, 0, 5, 4, 2, 2, 0, 4, 0, 0, 3, 2, 4, 3, 3, 3, 3, 3, 5, 1, 0, 3, 3, 2, 4, 1, 4, 3, 2, 4, 3, 2, 0, 3, 3, 2], [5, 4, 2, 2, 1, 5, 5, 3, 3, 2, 2, 2, 4, 4, 3, 3, 3, 3, 3, 3, 5, 5, 1, 5, 4, 3, 3, 3, 3, 3, 3, 3, 2, 4, 4, 4, 4, 4, 4, 3, 2, 2, 4, 4, 2, 4, 4, 3, 3, 3, 3, 2, 2, 0, 4, 3, 3, 2, 4, 0, 2, 2, 4, 0, 3, 2, 0, 5, 2, 2, 2], [5, 2, 2, 2, 5, 4, 1, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2, 4, 5, 2, 4, 4, 4, 1, 5, 2, 4, 4, 2, 2, 2, 2, 2, 2, 0, 0, 5, 1, 5, 4, 0, 2, 4, 1, 5, 4, 3, 2, 0, 3, 2, 4, 0, 4, 4, 3, 3, 2, 4, 4, 0, 2, 0, 3, 2], [5, 4, 4, 4, 4, 0, 2, 4, 3, 2, 2, 2, 2, 4, 1, 1, 5, 2, 1, 5, 4, 1, 5, 5, 0, 4, 3, 3, 3, 4, 4, 4, 3, 5, 4, 4, 3, 3, 3, 2, 2, 4, 0, 4, 5, 4, 3, 2, 4, 0, 4, 3, 5, 0, 0, 3, 3, 3, 3, 2, 5, 5, 4, 3, 3, 3, 2, 4, 0, 3, 2], [5, 2, 4, 0, 4, 4, 4, 4, 4, 3, 3, 2, 4, 4, 4, 4, 1, 5, 1, 1, 5, 4, 2, 4, 3, 5, 0, 2, 2, 4, 3, 3, 3, 5, 4, 1, 2, 4, 1, 5, 5, 1,

loc("mps_select"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/75428952-3aa4-11ee-8b65-46d450270006/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":294:0)): error: 'anec.gain_offset_control' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x9x1x1xi1>'
loc("mps_select"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/75428952-3aa4-11ee-8b65-46d450270006/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":294:0)): error: 'anec.gain_offset_control' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x9x1x1xi1>'


{'loss': 1133.2013, 'learning_rate': 0.0, 'epoch': 1.0}
{'train_runtime': 83.5951, 'train_samples_per_second': 0.108, 'train_steps_per_second': 0.012, 'train_loss': 1133.2012939453125, 'epoch': 1.0}


TrainOutput(global_step=1, training_loss=1133.2012939453125, metrics={'train_runtime': 83.5951, 'train_samples_per_second': 0.108, 'train_steps_per_second': 0.012, 'train_loss': 1133.2012939453125, 'epoch': 1.0})

In [29]:
metrics=trainer.evaluate()
print(metrics)

  0%|          | 0/1 [00:00<?, ?it/s]

predictions [[[-0.18970905  0.18024121  0.06065112 -0.11171746  0.16198713
   -0.02602383]
  [-0.05944495  0.09248458 -0.17799744 -0.15251164 -0.10702872
    0.13930583]
  [-0.29510376  0.20085649  0.25969154  0.10611022  0.35970443
   -0.13364089]
  ...
  [ 0.31057712 -0.00123025 -0.23006034 -0.10129994  0.27407625
    0.19301315]
  [ 0.13187079 -0.11723104 -0.19628857  0.02858959  0.11800069
    0.01459896]
  [-0.04808104 -0.0138107   0.08388251 -0.06492199 -0.06727362
   -0.05280718]]

 [[-0.15894619  0.07284678  0.14737347 -0.1276831   0.25255606
    0.09391508]
  [ 0.00786223  0.0871888   0.18400078 -0.19992264  0.26098937
    0.00664918]
  [-0.0803908   0.03582398  0.14990774 -0.09686233  0.14528047
    0.01309295]
  ...
  [ 0.15231846 -0.09140155 -0.04400397  0.13401577  0.05537358
    0.1083131 ]
  [ 0.05493654 -0.05631565  0.06448226  0.03003671 -0.03381815
   -0.20454642]
  [-0.02573016 -0.04879984  0.07133649 -0.05633134 -0.0750716
   -0.02148213]]

 [[-0.12695628  0.0138645

In [27]:
t5_lora_model.crf.modules_to_save.default.decode

<bound method CRF.decode of CRF(num_tags=6)>

---

In [36]:
training_log = pd.DataFrame(trainer.state.log_history)
display(training_log)

Unnamed: 0,loss,learning_rate,epoch,step,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss,eval_loss,eval_accuracy_metric,eval_precision_metric,eval_recall_metric,eval_f1_metric,eval_matthews_correlation,eval_confusion_matrix,eval_runtime,eval_samples_per_second,eval_steps_per_second
0,1133.2013,0.0,1.0,1,,,,,,,,,,,,,,,
1,,,1.0,1,83.5951,0.108,0.012,4646633000000.0,1133.201294,,,,,,,,,,
2,,,1.0,1,,,,,,382.030853,0.147619,0.147619,0.147619,0.147619,0.0,"[[31, 0, 0, 0, 0, 0], [26, 0, 0, 0, 0, 0], [39...",31.3562,0.096,0.032
3,,,1.0,1,,,,,,382.030853,0.147619,0.147619,0.147619,0.147619,0.0,"[[31, 0, 0, 0, 0, 0], [26, 0, 0, 0, 0, 0], [39...",32.5586,0.092,0.031
4,,,1.0,1,,,,,,382.030853,0.147619,0.147619,0.147619,0.147619,0.0,"[[31, 0, 0, 0, 0, 0], [26, 0, 0, 0, 0, 0], [39...",41.76,0.072,0.024


In [37]:
adapter_location = '/models/testing_1'
training_log['eval_confusion_matrix'] = training_log['eval_confusion_matrix'].apply(lambda x: x.tolist() if type(x)==np.ndarray else None)
t5_lora_model.save_pretrained(ROOT + adapter_location)
training_log.to_csv(ROOT + adapter_location + '/training_log.csv', index=False)
training_log.to_parquet(ROOT + adapter_location + '/training_log.parquet')

---

In [47]:
_ds_index = 2
_ds_type = 'test'

_input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'])
_labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
_attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

_labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test.tolist()[0][:-1]]

print('Iput IDs:\t', _input_ids_test)
print('Labels:\t\t', *_labels_test.tolist()[0])
print('Labels Decoded:\t', *_labels_test_decoded)
print('Attention Mask:\t', *_attention_mask_test.tolist()[0])
print('----')

viterbi_decoding=True

preds = src.model_new.predict_model(
    sequence=_input_ids_test,
    tokenizer=t5_tokenizer,
    model=t5_lora_model,
    labels=_labels_test,
    attention_mask=_attention_mask_test,
    device=device,
    viterbi_decoding=viterbi_decoding,
    )

_result = src.model_new.translate_logits(
    logits=preds.logits,
    viterbi_decoding=viterbi_decoding,
    )

print('Result: \t',* _result)

Iput IDs:	 M A A V I L E R L G A L W V Q N L R G K L A L G I L P Q S H I H T S A S L E I S R K W E K K N K I V Y P P Q L P G E P R R P A E I Y H C R R</s>
Labels:		 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -100
Labels Decoded:	 I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I
Attention Mask:	 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
----
Result: 	 M S S O O L T M S O O O M T S S O O O M M S S O L T T M M M M M S S S S S S S S O O O O O M S I O M I T T I T I S I I O O O O O M S S O O O M


In [None]:
# torch.set_printoptions(threshold=10_000)
# t5_lora_model.custom_classifier.modules_to_save.default.weight