In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's own `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import gc
import copy
import random

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

import peft
from peft import (
    LoraConfig,
    PeftModel
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    create_datasets
)
import src.config
import src.data
import src.model_new
import src.utils

# One seed shared by every random number generator seeded at the end of this cell.
SEED = 42

# Report the run configuration so a saved notebook records what it ran against.
print("Base Model:\t", src.config.base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = src.utils.get_project_root_path()
print("Path:\t\t", ROOT)

# Device preference: first CUDA GPU, then the MPS backend, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device:\t {device}")

# Seed PyTorch, Python's `random` and NumPy's legacy global generator alike.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Tokenizer belonging to the checkpoint named in src.config.base_model_name.
t5_tokenizer = T5Tokenizer.from_pretrained(
    src.config.base_model_name,
    legacy=False,
    use_fast=True,
    do_lower_case=False,
)

In [None]:
# Base model for token classification, loaded from the same checkpoint as the
# tokenizer. `custom_num_labels` is set to the number of entries in
# src.config.label_decoding; device placement is delegated via device_map='auto'.
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    load_in_8bit=False,
    device_map='auto',
)

In [None]:
# ex_lin_layer = nn.Linear(t5_base_model.config.hidden_size, len(src.config.label_decoding))
# ex_lin_layer.to(device)
# t5_base_model.custom_classifier.weight = ex_lin_layer.weight

In [None]:
# Location of the saved adapter; it is appended directly to ROOT by the cells
# below, hence the leading slash.
adapter_location = "/models/testing_12"

In [None]:
# t5_lora_model_reloaded = PeftModel.from_pretrained(
#     model = t5_base_model,
#     is_trainable=False,
#     model_id=ROOT+adapter_location,
#     custom_num_labels=len(src.config.label_decoding),
#     custom_dropout_rate=0.1,
# )
# Load the adapter stored at ROOT + adapter_location into the base model object.
t5_base_model.load_adapter(f"{ROOT}{adapter_location}")

In [None]:
# model = T5EncoderModelForTokenClassification.from_pretrained(adapter_location)

---

In [None]:
# Input file under <ROOT>/data/raw/; the '_testing' variant is the active one.
# FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
# Which annotation to use: 'Type' or 'Label'.
annotations_name = 'Label'

# Parse the raw file, then pass the parsed result through src.data.process.
df_data = src.data.process(
    src.data.parse_file(f"{ROOT}/data/raw/{FASTA_FILENAME}")
)

# Build the datasets, using the encoder that src.config registers for the
# selected annotation.
dataset_signalp = create_datasets(
    data=df_data,
    tokenizer=t5_tokenizer,
    splits=src.config.splits,
    dataset_size=src.config.dataset_size,
    annotations_name=annotations_name,
    encoder=src.config.select_encodings[annotations_name],
)

# The intermediate frame is not needed once the datasets exist.
del df_data

In [None]:
def predict_model(
    sequence: str,
    tokenizer: T5Tokenizer,
    model: T5EncoderModelForTokenClassification,
    max_length: int = 1024,
):
    """Run one sequence through ``model`` without tracking gradients.

    The sequence is tokenized with ``tokenizer.encode`` (truncated to
    ``max_length`` tokens) and the resulting ids are moved to the device that
    holds the model's parameters before the forward pass. The model's
    train/eval mode is not changed by this function.

    Args:
        sequence: Text to tokenize.
        tokenizer: Tokenizer whose ``encode`` method produces the input ids.
        model: Model that is called with the input ids.
        max_length: Truncation limit handed to the tokenizer (default 1024).

    Returns:
        Whatever ``model`` returns when called with the input ids.
    """
    input_ids = tokenizer.encode(
        sequence,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    )

    # Follow the model's own placement rather than a notebook-level `device`
    # variable: when the model is loaded with device_map='auto' the two are
    # not guaranteed to agree.
    first_param = next(model.parameters(), None)
    # Offloaded weights may report the placeholder "meta" device; moving real
    # inputs there would drop their data, so leave them where they are.
    if first_param is not None and first_param.device.type != "meta":
        input_ids = input_ids.to(first_param.device)

    with torch.no_grad():
        output = model(input_ids)
    return output

def translate_logits(logits):
    """Decode the first sequence of a batch of logits into label names.

    The class index with the highest score along the last axis is taken for
    every position; only the first entry of the leading axis is decoded, via
    ``src.config.label_decoding``.
    """
    first_sequence_ids = logits.argmax(-1).tolist()[0]
    decoded_labels = []
    for class_id in first_sequence_ids:
        decoded_labels.append(src.config.label_decoding[class_id])
    return decoded_labels

In [None]:
# Choose one example (split name and position) to inspect.
_ds_type, _ds_index = 'test', 2

# Turn the stored input ids back into text and look up the name of each label id.
_inids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'])
_labels_test = dataset_signalp[_ds_type][_ds_index]['labels']
_labels_test_decoded = [
    src.config.label_decoding[label_id.item()]
    for label_id in _labels_test
]
print(_inids_test, _labels_test, _labels_test_decoded, sep='\n')

In [None]:
# Feed the decoded example text back through the tokenizer and the model.
preds = predict_model(
    sequence=_inids_test,
    tokenizer=t5_tokenizer,
    model=t5_base_model,
)

In [None]:
# preds

In [None]:
# Copy the logits to the CPU as a NumPy array, decode them to label names and
# print the result.
_res = translate_logits(logits=preds.logits.cpu().numpy())
print(_res)

In [None]:
# torch.set_printoptions(threshold=10_000)
# t5_base_model.custom_classifier.weight

In [None]:
# Display the training log stored next to the adapter.
pd.read_csv(f"{ROOT}{adapter_location}/training_log.csv")

In [None]:
# Take the 'eval_confusion_matrix' entry at index label 1 of the parquet
# training log and convert it to a plain list.
data = pd.read_parquet(ROOT + adapter_location + '/training_log.parquet')['eval_confusion_matrix'][1].tolist()

# Draw on an explicit Axes and set the title there: a `label=` keyword given to
# sns.heatmap is only forwarded to the underlying pcolormesh call and is not
# shown on the figure.
fig, ax = plt.subplots()
sns.heatmap(np.array(data), annot=True, fmt='d', ax=ax)
ax.set_title('Confusion Matrix');

In [None]:
# for index, (param_name, param) in enumerate(t5_base_model.named_parameters()):
#     # if index == 11:
#     #     break
#     if param.requires_grad:
#         print(param_name)
#     # print(param)
#     if param_name in ['custom_classifier.bias', 'custom_classifier.weight']:
#         print(param)

In [None]:
# type(t5_base_model)

In [None]:
# [x for x in t5_base_model.custom_classifier.parameters()]

In [None]:
# t5_base_model.custom_classifier.weight