---
## Setup and Variables

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import copy
import random

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

import seaborn as sns

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
)

import src.utils

import sklearn.metrics

In [None]:
# Report the run configuration before anything heavy is loaded.
print("Base Model:\t", src.config.base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = src.utils.get_project_root_path()
print("Path:\t\t", ROOT)

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    _device_name = 'cuda:0'
elif torch.backends.mps.is_available():
    _device_name = 'mps'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print(f"Using device:\t {device}")

# Seed every RNG used in this notebook with the same value.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(42)

---
## Create Tokenizer and Load Model

In [None]:
# Tokenizer for the configured base checkpoint (src.config.base_model_name).
# NOTE(review): T5Tokenizer is the slow tokenizer class, so use_fast=True
# probably has no effect here — confirm against the transformers version in use.
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [None]:
# Load the project's T5EncoderModelForTokenClassification from the base checkpoint.
# custom_num_labels is set to the number of entries in src.config.label_decoding;
# custom_num_labels / custom_dropout_rate are project-specific kwargs
# (presumably consumed by src.model_new — verify there).
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    )

In [None]:
# Replace the classification head's weight with the weight of a freshly
# initialised nn.Linear of the same shape. The bias is left untouched.
_old_classifier_weight = t5_base_model.custom_classifier.weight
_fresh_classifier_weight = nn.Linear(
    in_features=t5_base_model.config.hidden_size,
    out_features=t5_base_model.custom_num_labels).weight.detach()

# A new nn.Linear is always created on the CPU in the default dtype, while
# device_map='auto' may have placed the head elsewhere. Match the weight being
# replaced so the head's weight and bias stay on the same device / dtype.
# (Skipped for 'meta' tensors, which carry no real storage to match.)
if _old_classifier_weight.device.type != 'meta':
    _fresh_classifier_weight = _fresh_classifier_weight.to(
        device=_old_classifier_weight.device,
        dtype=_old_classifier_weight.dtype)

t5_base_model.custom_classifier.weight = nn.Parameter(_fresh_classifier_weight)

In [None]:
# Display the classifier weight to inspect the values after re-initialisation.
t5_base_model.custom_classifier.weight

In [None]:
# t5_base_model_copy = copy.deepcopy(t5_base_model)

---
## Apply LoRA

In [None]:
# LoRA settings: rank 8, alpha 16, dropout 0.05, applied to modules named
# 'q', 'k', 'v' and 'o'. 'custom_classifier' is listed in modules_to_save so
# PEFT keeps its own copy of the head alongside the adapter weights.
lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],
    bias="none",
    modules_to_save=['custom_classifier'],
)

# Wrap the base model with the adapters and report how many parameters train.
t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

In [None]:
# Display the weight of the 'default' modules_to_save copy of the classifier inside the PEFT model.
t5_lora_model.base_model.custom_classifier.modules_to_save.default.weight

---
## Load Data, Split into Dataset, and Tokenize Sequences

In [None]:
# Input file and the annotation column used as the prediction target.
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
# FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
annotations_name = 'Label'  # Choose Type or Label

# Parse and process the raw FASTA file into a DataFrame.
fasta_path = ROOT + '/data/raw/' + FASTA_FILENAME
df_data = src.data.process(src.data.parse_file(fasta_path))

# Build the tokenized dataset splits; the encoder is picked to match the
# chosen annotation column.
dataset_signalp = src.model_new.create_datasets(
    splits=src.config.splits,
    tokenizer=t5_tokenizer,
    data=df_data,
    annotations_name=annotations_name,
    dataset_size=src.config.dataset_size,
    encoder=src.config.select_encodings[annotations_name],
)

# The DataFrame is no longer needed once the datasets exist.
del df_data

In [None]:
# Show the dataset structure, then the three model inputs of the first validation example.
display(dataset_signalp)
_first_valid_example = dataset_signalp['valid'][0]
for _field in ('input_ids', 'labels', 'attention_mask'):
    print(_first_valid_example[_field])

---
## Training Loop
Reference: [PEFT token-classification LoRA guide](https://huggingface.co/docs/peft/task_guides/token-classification-lora)

In [None]:
# Load the `evaluate` metrics once at module level; batch_eval_elementwise
# reads these names as globals.
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
# roc_auc_score_metric = evaluate.load("roc_auc", "multiclass")
matthews_correlation_metric = evaluate.load("matthews_correlation")

In [None]:
def batch_eval_elementwise(predictions: np.ndarray, references: np.ndarray):
    """Score per-token predictions against references, sequence by sequence.

    Positions whose reference equals -100 are dropped. Accuracy, precision,
    recall, F1 (all micro-averaged) and Matthews correlation are computed for
    every sequence separately and then averaged over the sequences. A single
    confusion matrix is computed over all kept positions.

    Args:
        predictions: Scores whose last axis is reduced with argmax to obtain
            the predicted class ids. NaN values are replaced via
            ``np.nan_to_num`` (after printing ``'has nan'``).
        references: Reference class ids, one row per sequence, with -100 at
            positions to ignore.

    Returns:
        dict with the keys ``accuracy_metric``, ``precision_metric``,
        ``recall_metric``, ``f1_metric``, ``matthews_correlation`` and
        ``confusion_matrix``.
    """
    if np.isnan(predictions).any():
        print('has nan')
        predictions = np.nan_to_num(predictions)

    predicted_ids = predictions.argmax(axis=-1)

    # One (predicted, reference) pair per sequence, ignored positions removed.
    pairs = []
    for pred_row, ref_row in zip(predicted_ids.tolist(), references):
        keep = ref_row != -100
        pairs.append((np.array(pred_row)[keep], np.array(ref_row)[keep]))

    pred_seqs, ref_seqs = zip(*pairs)

    # NOTE(review): predictions are passed as y_true and references as y_pred,
    # so rows index the *predicted* class and columns the *reference* class
    # (the transpose of scikit-learn's usual layout). Kept as-is because
    # consumers may rely on this orientation — confirm before changing.
    confusion = sklearn.metrics.confusion_matrix(
        y_true=np.concatenate(pred_seqs),
        y_pred=np.concatenate(ref_seqs),
    )

    def _sequence_mean(metric, key, **compute_kwargs):
        # Average one metric over all sequences.
        per_sequence = [
            metric.compute(predictions=pred, references=ref, **compute_kwargs)[key]
            for pred, ref in pairs
        ]
        return np.average(per_sequence)

    results = {
        'accuracy_metric': _sequence_mean(accuracy_metric, 'accuracy'),
        'precision_metric': _sequence_mean(precision_metric, 'precision', average='micro'),
        'recall_metric': _sequence_mean(recall_metric, 'recall', average='micro'),
        'f1_metric': _sequence_mean(f1_metric, 'f1', average='micro'),
        # results.update({'roc_auc': [roc_auc_score_metric.compute(prediction_scores=x, references=y, multi_class='ovr', average=None)['roc_auc'] for x, y in zip(softmax_predictions, references)]})
        'matthews_correlation': _sequence_mean(matthews_correlation_metric, 'matthews_correlation', average='micro'),
        'confusion_matrix': confusion,
    }
    return results

def compute_metrics(p):
    """Trainer hook: unpack ``p`` into (predictions, references) and score them.

    Args:
        p: An object that unpacks into exactly two items, predictions first
            and references second.

    Returns:
        The metrics dict produced by ``batch_eval_elementwise``.
    """
    logits, label_ids = p
    return batch_eval_elementwise(predictions=logits, references=label_ids)


In [None]:
# t5_lora_model.to(device)

In [None]:
# vals = 0,1,2
# with torch.no_grad():
#     preds = t5_lora_model(
#         input_ids=torch.tensor(dataset_signalp['valid'][vals]['input_ids']).to(device),
#         # attention_mask=torch.tensor(dataset_signalp['test'][0:1]['attention_mask']).to(device),
#         # labels=torch.tensor(dataset_signalp['test'][0:1]['labels']).to(device)
#         )

In [None]:
# _p = preds.logits.cpu().numpy()
# _t = np.array([np.pad(x, (0, 71 - len(x)), mode='constant', constant_values=-100) for x in dataset_signalp['valid'][vals]['labels']])
# metrics = compute_metrics(p=(_p, _t))
# print(metrics)

In [None]:
# ax = sns.heatmap(
#     metrics['confusion_matrix'],
#     annot=True,
#     xticklabels=[src.config.label_decoding[label] for label in range(len(src.config.label_decoding))],
#     yticklabels=[src.config.label_decoding[label] for label in range(len(src.config.label_decoding))],
#     )

# ax.set_title('Confusion Matrix')
# ax.set_xlabel('Actual')
# ax.set_ylabel('Predicted')

In [None]:
# pd.DataFrame(preds.logits.softmax(dim=-1).cpu()[0]).plot()

In [None]:
# roc_auc_score = evaluate.load("roc_auc", "multiclass")
# refs = [1, 0, 1, 2, 2, 0]
# pred_scores = [[0.3, 0.5, 0.2],
#                [0.7, 0.2, 0.1],
#                [0.005, 0.99, 0.005],
#                [0.2, 0.3, 0.5],
#                [0.1, 0.1, 0.8],
#                [0.1, 0.7, 0.2]]
# results = roc_auc_score.compute(references=refs,
#                                 prediction_scores=pred_scores,
#                                 multi_class='ovr')
# print(round(results['roc_auc'], 2))

In [None]:
# print(np.array(refs))
# print(np.array(pred_scores))

In [None]:
# p = preds.logits.cpu().numpy()
# t = np.array(dataset_signalp['valid'][0:1]['labels'])
# t = np.pad(t, ((0, 0), (0, 71 - t.shape[1])), mode='constant', constant_values=-100)
# compute_metrics(p=(p, t))

In [None]:
# x = 10
# truth = t[:, :x][0]
# truth = [1,2,3,4,5,6,5,5,5,5]
# preds = torch.tensor(p[:, :x][0]).softmax(axis=-1).cpu().numpy()
# print(truth)
# print(preds)

In [None]:

# results = roc_auc_score.compute(references=truth,
#                                 prediction_scores=preds,
#                                 multi_class='ovr',
#                                 labels=[1, 2, 3, 4, 5, 6])
# print(round(results['roc_auc'], 2))

In [None]:
# predictions = torch.tensor([
#     # [[1,0,0,0,1], [2,3,2,1,2], [2,2,5,1,3]],
#     [[1,2,0,0,0], [1,float('nan'),1,0,100], [1,4,3,7,10]]
#     ])

# references = torch.tensor([
#     # [0,1,2],
#     [1,4,4]
#     ])

# def softmax(X, axis=0):
#     return np.exp(X)#/np.sum(np.exp(X), axis=axis)

# print(softmax(predictions.numpy(), axis=2))
# print()
# print(predictions.softmax(dim=-1))

# print(predictions.softmax(dim=-1)[0][0])
# compute_metrics((predictions.numpy(), references.numpy()))

In [None]:
# Collator that batches the tokenized examples for token classification.
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

# Hyperparameters come from src.config. The save / evaluation schedule options
# are commented out, so those fall back to the TrainingArguments defaults.
training_args = TrainingArguments(
    output_dir='./checkpoints',
    learning_rate=src.config.lr,
    per_device_train_batch_size=src.config.batch_size,
    per_device_eval_batch_size=src.config.batch_size,
    num_train_epochs=src.config.num_epochs,
    logging_steps=src.config.logging_steps,
    # save_strategy="steps",
    # save_steps=src.config.save_steps,
    # evaluation_strategy="steps",
    # eval_steps=src.config.eval_steps,
    # gradient_accumulation_steps=accum,
    # load_best_model_at_end=True,
    # save_total_limit=5,
    seed=42,
    # fp16=True,
    # deepspeed=deepspeed_config,
    remove_unused_columns=False,
    label_names=['labels'],
    # debug="underflow_overflow",
)

# Train the LoRA-wrapped model on the 'train' split and evaluate on 'valid'
# using compute_metrics.
trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# class EvaluateFirstStepCallback(TrainerCallback):
#     def on_step_begin(self, args, state, control, **kwargs):
#         if state.global_step == 0:
#             control.should_evaluate = True
# trainer.add_callback(EvaluateFirstStepCallback())

In [None]:
# Release unreachable Python objects, then the cached accelerator memory of
# whichever backend is available. This avoids hand-toggling the MPS call
# when switching machines.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available() and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'):
    torch.mps.empty_cache()

In [None]:
# Run the training loop configured in `trainer`.
trainer.train()

---

In [None]:
# Evaluate on the Trainer's eval dataset and show the resulting metrics.
metrics = trainer.evaluate()
print(metrics)

---

In [None]:
# Collect everything the Trainer logged into a DataFrame and display it.
training_log = pd.DataFrame(trainer.state.log_history)
display(training_log)

In [None]:
def _confusion_matrix_to_list(value):
    """Return a logged confusion matrix as a nested list, or None if absent.

    Lists are passed through unchanged, so re-running this cell does not wipe
    entries that were already converted.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, list):
        return value
    return None


# Convert ndarray entries to plain lists so the log can be written to disk.
# The column only exists once at least one evaluation has been logged.
if 'eval_confusion_matrix' in training_log.columns:
    training_log['eval_confusion_matrix'] = training_log['eval_confusion_matrix'].apply(_confusion_matrix_to_list)

In [None]:
# Location of the saved adapter, relative to the project root.
adapter_location = '/models/testing_13'
_adapter_dir = ROOT + adapter_location

# Save the PEFT model, then the training log next to it.
t5_lora_model.save_pretrained(_adapter_dir)
# training_log.to_csv(ROOT + adapter_location + '/training_log.csv', index=False)
training_log.to_parquet(_adapter_dir + '/training_log.parquet')

---

In [None]:
# for name, param in t5_lora_model.base_model.named_parameters():
#     if "lora" not in name:
#         continue
#     if param.isnan().any():
#         print(f"New parameter {name:<13} | {param.numel():>5} parameters | not updated")
#     else:
#         print(f"New parameter {name:<13} | {param.numel():>5} parameters | updated")

In [None]:
# params_before = dict(t5_base_model_copy.named_parameters())
# for name, param in t5_lora_model.base_model.named_parameters():
#     if "lora" in name:
#         continue

#     name_before = name.partition(".")[-1].replace("original_", "").replace("module.", "").replace("modules_to_save.default.", "")
#     param_before = params_before[name_before]
#     if torch.allclose(param, param_before):
#         print(f"Parameter {name_before:<14} | {param.numel():>7} parameters | not updated")
#     else:
#         print(f"Parameter {name_before:<14} | {param.numel():>7} parameters | updated")

---

In [None]:
def predict_model(sequence: str, tokenizer: T5Tokenizer, model: T5EncoderModelForTokenClassification, attention_mask=None, labels=None,):
    """Tokenize ``sequence`` and run one gradient-free forward pass of ``model``.

    The model is switched to eval mode for the call, so dropout cannot make
    predictions non-deterministic. Its previous training mode is restored
    afterwards.

    Args:
        sequence: Text to encode with ``tokenizer`` (truncated to 1024 tokens).
        tokenizer: Tokenizer used to encode ``sequence`` into input ids.
        model: Model to run.
        attention_mask: Optional attention mask forwarded to the model.
        labels: Optional labels forwarded to the model.

    Returns:
        The model output, unchanged.
    """
    tokenized_string = tokenizer.encode(sequence, padding=True, truncation=True, return_tensors="pt", max_length=1024)

    # Run on the device the model lives on instead of relying solely on the
    # notebook-level `device`; fall back to it when the model has no usable
    # parameter device.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.device.type != 'meta':
        target_device = first_param.device
    else:
        target_device = device

    def _to_target(value):
        # Only tensors can be moved; None and other values pass through.
        return value.to(target_device) if isinstance(value, torch.Tensor) else value

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(
                input_ids=tokenized_string.to(target_device),
                labels=_to_target(labels),
                attention_mask=_to_target(attention_mask),
            )
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model.train()
    return output

def translate_logits(logits):
    """Decode the first sequence in ``logits`` into label names.

    The argmax over the last axis gives the class id per position; each id is
    looked up in ``src.config.label_decoding``.
    """
    first_sequence_ids = logits.argmax(-1).tolist()[0]
    decoded = []
    for class_id in first_sequence_ids:
        decoded.append(src.config.label_decoding[class_id])
    return decoded

In [None]:
# Pick one example to inspect.
_ds_type = 'test'
_ds_index = 2
_sample = dataset_signalp[_ds_type][_ds_index]

# Decode the ids back to text; predict_model re-tokenizes it.
_input_ids_test = t5_tokenizer.decode(_sample['input_ids'])
# One extra -100 is appended to the labels.
# NOTE(review): presumably to cover a trailing special token — confirm.
_labels_test = torch.tensor(_sample['labels'] + [-100]).to(device)
_attention_mask_test = torch.tensor([_sample['attention_mask']]).to(device)

# Human-readable labels, without the appended -100.
_labels_test_decoded = [src.config.label_decoding[label_id] for label_id in _labels_test.tolist()[:-1]]

for _item in (_input_ids_test, _labels_test, _labels_test_decoded, _attention_mask_test):
    print(_item)
print('----')

In [None]:
# Run the LoRA model on the selected sample, forwarding its labels and attention mask.
preds = predict_model(
    sequence=_input_ids_test,
    tokenizer=t5_tokenizer,
    model=t5_lora_model,
    labels=_labels_test,
    attention_mask=_attention_mask_test
    )

In [None]:
# Display the largest logit value in the prediction.
preds.logits.max()

In [None]:
# Convert the logits to label names and print them.
_res = translate_logits(preds.logits.cpu().numpy())
print(_res)

In [None]:
# torch.set_printoptions(threshold=10_000)
# t5_lora_model.custom_classifier.modules_to_save.default.weight

---
---

In [None]:
# `t5_base_model_copy` is not defined by any active cell (the deepcopy earlier
# in the notebook is commented out), and `t5_base_model` has since been
# wrapped by PEFT. Load a fresh base model with the same settings so this
# section runs on a clean kernel.
t5_base_model_copy = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
)
# Attach the adapter that was saved under ROOT + adapter_location.
t5_base_model_copy.load_adapter(ROOT+adapter_location)

In [None]:
# Move the reloaded model to the selected device.
t5_base_model_copy.to(device)

In [None]:
# Predict with the reloaded model on the same sample as before.
# (`_inids_test` was never defined; the decoded sample is `_input_ids_test`.)
preds = predict_model(_input_ids_test, t5_tokenizer, t5_base_model_copy)

In [None]:
# Convert the reloaded model's logits to label names and print them.
_res = translate_logits(preds.logits.cpu().numpy())
print(_res)

In [None]:
# Classifier parameters of the reloaded model.
_reloaded_head = t5_base_model_copy.custom_classifier
for _attr in ('weight', 'bias'):
    print(getattr(_reloaded_head, _attr))

In [None]:
# Classifier parameters of the 'default' modules_to_save copy in the trained PEFT model.
_trained_head = t5_lora_model.model.custom_classifier.modules_to_save.default
for _attr in ('weight', 'bias'):
    print(getattr(_trained_head, _attr))

In [None]:
# Classifier parameters of the original_module kept inside the PEFT wrapper.
_original_head = t5_lora_model.model.custom_classifier.original_module
for _attr in ('weight', 'bias'):
    print(getattr(_original_head, _attr))

In [None]:
# Display the wrapped classifier module.
t5_lora_model.model.custom_classifier

---

In [None]:
# Named parameters of both models, skipping any whose name contains 'original'.
params_trained = [
    (name, tensor)
    for name, tensor in t5_lora_model.named_parameters()
    if 'original' not in name
]
params_reloaded = [
    (name, tensor)
    for name, tensor in t5_base_model_copy.named_parameters()
    if 'original' not in name
]

# NOTE(review): parameters are paired by position, not by name — this assumes
# both models enumerate their parameters in the same order; confirm.
for (name_trained, tensor_trained), (name_reloaded, tensor_reloaded) in zip(params_trained, params_reloaded):
    # Report only the pairs that differ.
    if not torch.eq(tensor_trained.data, tensor_reloaded.data).all():
        print(f"Parameter {name_trained} and {name_reloaded} not equal")

In [None]:
# Small example table, one row per region: (index key, name, capital, population, area).
# NOTE(review): this frame is not used by anything else in the notebook.
_region_rows = [
    ('LF', 'Lower Franconia', 'Würzburg', 1_320_513, 8_530.99),
    ('UF', 'Upper Franconia', 'Bayreuth', 1_061_929, 7_230.19),
    ('MF', 'Middle Franconia', 'Ansbach', 1_775_169, 7_245.70),
    ('UP', 'Upper Palatinate', 'Regensburg', 1_112_102, 9_692.23),
    ('S', 'Swabia', 'Augsburg', 1_899_442, 9_993.97),
    ('UB', 'Upper Bavaria', 'Munich', 4_710_865, 17_529.41),
    ('LB', 'Lower Bavaria', 'Landshut', 1_244_169, 10_329.87),
]
df = pd.DataFrame(
    [row[1:] for row in _region_rows],
    columns=['name', 'capital', 'population', 'area'],
)
df.index = pd.Index([row[0] for row in _region_rows])

---