## Notebook Setup
___

In [None]:
# Reload edited modules (e.g. src.model, src.config) automatically before each cell runs,
# so changes to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Packages
___

In [None]:
import re
import os
import math
import copy
import types
import yaml
import gc
    
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)
from torch.utils.data import DataLoader

import evaluate

from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

from datasets import Dataset

import src.config as config

from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_full,
    compute_metrics_fast
    )
from src.utils import get_project_root_path

---
## Setup and Variables

In [None]:
# Report the base model and project root, then pick the compute device:
# CUDA if present, otherwise MPS, otherwise CPU.
base_model_name = config.base_model_name
print("Base Model:\t", base_model_name)

mps_available = torch.backends.mps.is_available()
print("MPS:\t\t", mps_available)

ROOT = get_project_root_path()
print("Path:\t\t", ROOT)

if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif mps_available:
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device:\t {device}")

In [None]:
# Training hyperparameters pulled from src.config. Only kept for reference in this
# inference notebook: none of lr / batch_size / num_epochs / dropout_rate are read
# by the cells below (the head injection reads config.dropout_rate directly).
lr = config.lr
batch_size = config.batch_size
num_epochs = config.num_epochs
dropout_rate = config.dropout_rate

# Mapping between label strings and class ids (label_list is the id -> label decoding).
label_encoding = config.label_encoding
label_list = config.label_decoding

# Metric function used for evaluation (fast variant from src.model).
compute_metrics = compute_metrics_fast

---
## Create Tokenizer and Load Model

In [5]:
# Load the tokenizer and the encoder-only T5 base model for `base_model_name`.
model_architecture = T5EncoderModel
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(base_model_name, model_architecture)

---
## Load Adapter

In [6]:
# Location of the saved adapter, relative to the project root.
adapter_location = '/models/linear_model_v3'
# Adapter configuration; informational only — not used by the visible cells below.
t5_lora_model_config = PeftConfig.from_pretrained(ROOT + adapter_location)

In [7]:
# Wrap the base encoder with the saved LoRA adapter, loaded as non-trainable.
# Guarded so that re-running this cell on its own does not wrap an already
# adapter-wrapped model a second time (re-run the model-loading cell to start fresh).
if not isinstance(t5_base_model, PeftModel):
    t5_base_model = PeftModel.from_pretrained(
        model=t5_base_model,
        model_id=ROOT + adapter_location,
        is_trainable=False,
    )

In [8]:
# Add the linear classification layer (one output per decoded label) to the
# adapter-wrapped encoder.
t5_lora_model = inject_linear_layer(
    t5_lora_model=t5_base_model,
    num_labels=len(config.label_decoding),
    dropout_rate=config.dropout_rate,
)

In [9]:
# Inspection cell: uncomment to display the model's forward method
# (e.g. to check which forward is bound after the head injection).
# t5_lora_model.forward

In [15]:
# Spot-check: display the LoRA A matrix of the value projection in encoder block 4.
t5_lora_model.encoder.block[4].layer[0].SelfAttention.v.lora_A.default.weight

Parameter containing:
tensor([[ 0.0195,  0.0400, -0.0119,  ...,  0.0237, -0.0318, -0.0082],
        [-0.0312, -0.0007,  0.0204,  ..., -0.0371, -0.0603,  0.0122],
        [-0.0065, -0.0192,  0.0001,  ..., -0.0070, -0.0275, -0.0091],
        ...,
        [-0.0391, -0.0258,  0.0162,  ..., -0.0222, -0.0073, -0.0247],
        [ 0.0190, -0.0202,  0.0442,  ...,  0.0069, -0.0004, -0.0178],
        [ 0.0224, -0.0318, -0.0377,  ..., -0.0311,  0.0143,  0.0004]],
       device='mps:0')

In [17]:
 for name, param in t5_lora_model.base_model.named_parameters():
    if "lora" not in name:
        print(f"New parameter {name:<13} | {param.numel():>5} parameters")
        continue
    if param.isnan().any():
        print(f"New parameter {name:<13} | {param.numel():>5} parameters | not updated")
    else:
        print(f"New parameter {name:<13} | {param.numel():>5} parameters | updated")

New parameter model.shared.weight | 131072 parameters
New parameter model.encoder.block.0.layer.0.SelfAttention.q.weight | 4194304 parameters
New parameter model.encoder.block.0.layer.0.SelfAttention.q.lora_A.default.weight |  8192 parameters | updated
New parameter model.encoder.block.0.layer.0.SelfAttention.q.lora_B.default.weight | 32768 parameters | updated
New parameter model.encoder.block.0.layer.0.SelfAttention.k.weight | 4194304 parameters
New parameter model.encoder.block.0.layer.0.SelfAttention.k.lora_A.default.weight |  8192 parameters | updated
New parameter model.encoder.block.0.layer.0.SelfAttention.k.lora_B.default.weight | 32768 parameters | updated
New parameter model.encoder.block.0.layer.0.SelfAttention.v.weight | 4194304 parameters
New parameter model.encoder.block.0.layer.0.SelfAttention.v.lora_A.default.weight |  8192 parameters | updated
New parameter model.encoder.block.0.layer.0.SelfAttention.v.lora_B.default.weight | 32768 parameters | updated
New parameter mo

---
## Load Data

In [None]:
# Load the processed data table; the evaluation split is selected from it below.
df_data = pd.read_parquet(ROOT + '/data/processed/5.0_train.parquet.gzip')

In [None]:
# Peek at the first rows of the loaded table.
df_data.head()

In [None]:
# Build the evaluation Dataset from split 4, capped at config.dataset_size rows.
# The DataFrame slice and the tokenized Dataset get distinct names so `ds_test`
# always refers to the Dataset (re-running the cell no longer depends on which
# kind of object `ds_test` currently holds).
# ToDo: Use entire test set
df_test = df_data[df_data.Split.isin([4])].head(config.dataset_size)
ds_test = df_to_dataset(
    t5_tokenizer,
    df_test.Sequence.to_list(),
    df_test.Label.to_list()
)

In [None]:
# Display the tokenized test Dataset.
ds_test

In [None]:
# Print the token ids, attention mask and label ids of the first test example.
for field in ('input_ids', 'attention_mask', 'labels'):
    print(*ds_test[field][0])

In [None]:
# Decode the first example back to text. The last token is dropped ([:-1]);
# presumably that is the end-of-sequence token — TODO confirm against df_to_dataset.
input_str = t5_tokenizer.decode(ds_test['input_ids'][0][:-1])
print(input_str)

In [None]:
# Re-tokenize the decoded string to compare with the stored input_ids printed above.
inputs = t5_tokenizer(input_str)
print(inputs)

---
## Run Inference

In [None]:
# Stack all input_ids into one tensor on `device`. torch.tensor needs equal-length
# rows, so this assumes df_to_dataset pads every sequence to a common length — TODO confirm.
inid = torch.tensor(ds_test['input_ids']).to(device)
print(inid.shape)

In [None]:
# Move the model to the selected device (the cell output is the model repr).
t5_lora_model.to(device)

In [None]:
# Run the first few test sequences through the model one at a time and keep,
# per sequence, the list of predicted class ids (argmax over the logits). The
# comparison cell below decodes these ids with config.label_decoding.
N_SEQUENCES = 10  # how many sequences to evaluate in this quick check

# Eval mode disables dropout (e.g. in the injected head) so predictions are deterministic.
t5_lora_model.eval()
# Mask out padding positions, matching how the inputs were tokenized.
attention_masks = torch.tensor(ds_test['attention_mask']).to(device)

results = []
with torch.no_grad():  # inference only: do not build autograd graphs
    for index in range(min(N_SEQUENCES, len(inid))):
        if index % 100 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
        logits = t5_lora_model(
            input_ids=inid[index:index + 1],
            attention_mask=attention_masks[index:index + 1],
        ).logits
        # logits has a leading batch dim of 1, so this appends one list of ids.
        results += logits.argmax(dim=-1).tolist()

In [None]:
# Number of sequences that were run through the model.
len(results)

In [None]:
# Label ids of the first test sequence.
ds_test['labels'][0]

In [None]:
# Model output collected for the first sequence.
results[0]

In [None]:
# Input ids of the first sequence (tensor on `device`).
inid[0]

In [None]:
# Same display as the earlier `results[0]` cell (duplicate).
results[0]

In [None]:
# Decode every label-id sequence of the test set into its label strings.
ground_truth = [
    [config.label_decoding[label_id] for label_id in label_ids]
    for label_ids in ds_test['labels']
]

In [None]:
# Number of decoded ground-truth sequences (one per test example).
print(len(ground_truth))

In [None]:
# Print truth vs. prediction for each evaluated sequence and tally per-position
# matches. Predictions are cut to the length of the ground truth, and only
# positions present in both are compared.
correct = incorrect = 0

for index, item in enumerate(results):
    truth = ground_truth[index]
    prediction = [config.label_decoding[label_id] for label_id in item[:len(truth)]]

    print('T: ', *truth, sep='')
    print('P: ', *prediction, sep='')
    print()

    pairs = list(zip(truth, prediction))
    matches = sum(t == p for t, p in pairs)
    correct += matches
    incorrect += len(pairs) - matches

print("Correct", correct)
print("Incorrect", incorrect)

In [None]:
# Token-level accuracy over the compared positions. Prints nan instead of raising
# ZeroDivisionError when nothing was compared (e.g. `results` is empty).
total_tokens = correct + incorrect
print(correct / total_tokens if total_tokens else float('nan'))

In [None]:
# Run Python's garbage collector; the displayed value is the number of unreachable objects found.
gc.collect()

---
## Measure Performance

---
---
---
## Batched Evaluation (DataLoader)

Q02742|EUKARYA|NO_SP|4\
MLRTLLRRRLFSYPTKYYFMVLVLSLITFSVLRIHQKPEFVSVRHLELAGENPSSDINCTKVLQGDVNEI\
IIIIIIIIIMMMMMMMMMMMMMMMMMMMMMMMOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO

Q9NQR9|EUKARYA|NO_SP|4\
MDFLHRNGVLIIQHLQKDYRAYYTFLNFMSNVGDPRNIFFIYFPLCFQFNQTVGTKMIWVAVIGDWLNLI\
OOOOOOOOOOOOOOOOOOOOOOOOMMMMMMMMMMMMMMMMMMMMMIIIIIIIIIIIMMMMMMMMMMMMMM

In [None]:
# Display the model architecture.
t5_lora_model

In [None]:
# Batched evaluation of the test Dataset on CPU.
# NOTE: this rebinds the notebook-wide `device` to CPU; tensors created by earlier
# cells stay on whatever device was selected at the top.
device = 'cpu'
EVAL_BATCH_SIZE = 16

t5_lora_model.to(device)

test_set = ds_test.with_format("torch", device=device)

# For token classification we need a data collator here to pad correctly
data_collator = DataCollatorForTokenClassification(t5_tokenizer)

# Create a dataloader for the test dataset
test_dataloader = DataLoader(test_set, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=data_collator)

# Put the model in evaluation mode (disables dropout)
t5_lora_model.eval()

# Predicted class ids, one list per sequence (padded to its batch's length)
predictions = []
# We need to collect the batch["labels"] as well, this allows us to filter out all positions with a -100 afterwards
padded_labels = []

with torch.no_grad():
    for batch_index, batch in enumerate(test_dataloader):
        print(f"batch {batch_index + 1}/{len(test_dataloader)}")
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        padded_labels += batch['labels'].tolist()
        # Pass the collator's attention mask so padding positions do not influence the
        # real tokens, then take the argmax over the logits to get the predicted class.
        logits = t5_lora_model(input_ids=input_ids, attention_mask=attention_mask).logits
        predictions += logits.argmax(dim=-1).tolist()

In [None]:
# Compare ground truth and prediction for one test sequence. Both are taken from
# the collated batches and restricted to positions whose padded label is not -100
# (the collator's padding/ignore value), so the two printed sequences line up.
index_item = 0

labels_item = padded_labels[index_item]
kept_positions = [pos for pos, label_id in enumerate(labels_item) if label_id != -100]

actual = [config.label_decoding[labels_item[pos]] for pos in kept_positions]
print(len(actual))
print(*actual)
pred = [config.label_decoding[predictions[index_item][pos]] for pos in kept_positions]
print(len(pred))
print(*pred)