In [None]:
%load_ext autoreload
%autoreload 2

import gc
import copy

import torch
import pandas as pd
import numpy as np

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    create_datasets
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
    PeftModel
)

# Pick the compute device in order of preference: first CUDA GPU, then Apple
# MPS, and finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Relative prefix that later cells prepend to project paths (e.g. the adapter
# directory).
ROOT = '../'

# Seed torch's global RNG so repeated runs of the notebook are comparable.
torch.manual_seed(42)

In [None]:
# Load the tokenizer for the base checkpoint named in src.config.
# NOTE(review): `use_fast` is normally an AutoTokenizer switch; confirm it has
# any effect when calling T5Tokenizer.from_pretrained directly.
t5_tokenizer = T5Tokenizer.from_pretrained(
    src.config.base_model_name,
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)

In [None]:
# Instantiate the token-classification model from the base checkpoint.
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    # Head configuration: one output class per entry in the label table.
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    # Placement / precision options forwarded to the loader.
    device_map='auto',
    load_in_8bit=False,
)

In [None]:
# Adapter directory; it is appended to ROOT to form the path that gets loaded.
adapter_location = '/models/testing_1'

# Load the adapter stored at ROOT + adapter_location into t5_base_model.
# An earlier variant, kept here for reference, wrapped the model instead:
#   PeftModel.from_pretrained(
#       model=t5_base_model, model_id=ROOT+adapter_location, is_trainable=False,
#       custom_num_labels=len(src.config.label_decoding), custom_dropout_rate=0.1)
t5_base_model.load_adapter(f'{ROOT}{adapter_location}')

In [None]:
# t5_lora_model_reloaded

In [None]:
def predict_model(sequence: str, tokenizer: T5Tokenizer, model: T5EncoderModelForTokenClassification):
    """Tokenize one sequence and run a gradient-free forward pass of ``model``.

    Args:
        sequence: Input text to encode.
        tokenizer: Tokenizer whose ``encode`` produces the PyTorch input ids;
            inputs are truncated to at most 1024 tokens.
        model: Model that is called with the encoded ids.

    Returns:
        Whatever ``model(...)`` returns, unmodified (the notebook later reads
        ``.logits`` from it).

    Note:
        The ids are moved to the notebook-level ``device`` before the call, so
        that global must be defined before this function is used.
    """
    encode_options = dict(
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=1024,
    )
    input_ids = tokenizer.encode(sequence, **encode_options).to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        return model(input_ids)

In [None]:
# Example input for the prediction cells below: single-letter tokens separated by spaces.
# Previous example (note the explicit trailing `</s>` marker), kept for reference:
# test_seq = 'M A P T L F Q K L F S K R T G L G A P G R D A R D P D C G F S W P L P E F D P S Q I R L I V Y Q D C E R R G R N V L F D S S V K R R N E D I</s>'
test_seq = 'M L C F W R T S H V A V L L I W G V F A A E S S C P D K N Q T M Q N N S S T M T E V N T T V F V Q M G K K A L L C C P S I S L T K V I L I T'

In [None]:
# Run the adapter-augmented model on the example sequence.
preds = predict_model(
    sequence=test_seq,
    tokenizer=t5_tokenizer,
    model=t5_base_model,
)

In [None]:
def tranlate_logits(logits, label_decoding=None):
    """Decode the logits of the first batch item into a list of labels.

    Args:
        logits: Array-like exposing ``argmax(-1)`` whose result can be indexed
            and converted with ``tolist()`` (the notebook passes a NumPy
            array). Assumed layout is (batch, tokens, classes) - TODO confirm;
            only the first entry along the leading axis is decoded.
        label_decoding: Optional lookup (mapping or sequence) from class index
            to label. When omitted, ``src.config.label_decoding`` is used, which
            is the table the notebook relied on before this parameter existed.

    Returns:
        A list with one decoded label per position of the first batch item.
    """
    if label_decoding is None:
        # Resolve the project default lazily so callers can supply their own table.
        label_decoding = src.config.label_decoding

    # Index the first batch item before converting to a Python list, so only
    # the values that are actually decoded get materialised.
    first_item_class_ids = logits.argmax(-1)[0].tolist()
    return [label_decoding[class_id] for class_id in first_item_class_ids]

In [None]:
# Convert the model's logits to a NumPy array on the CPU, decode them into
# labels, and print the decoded list.
_res = tranlate_logits(preds.logits.cpu().numpy())
print(_res)

In [None]:
# Sanity check after loading the adapter: print the name of every parameter
# that still requires gradients, and print the classifier head's tensors in full.
for param_name, param in t5_base_model.named_parameters():
    if param.requires_grad:
        print(param_name)
    if param_name in ('custom_classifier.bias', 'custom_classifier.weight'):
        print(param)

In [None]:
# Display the concrete class of the model object after the adapter was loaded.
type(t5_base_model)

In [None]:
# Show every parameter tensor of the classification head as a list.
list(t5_base_model.custom_classifier.parameters())