## Notebook Setup
___

In [1]:
%load_ext autoreload
%autoreload 2

## Packages
___

In [2]:
import re
import os
import math
import copy
import types
import yaml
import gc
    
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)
from torch.utils.data import DataLoader

import evaluate

from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

from datasets import Dataset

import src.config as config

from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_full,
    compute_metrics_fast
    )
from src.utils import get_project_root_path

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


---
## Setup and Variables

In [3]:
base_model_name = config.base_model_name
print("Base Model:\t", base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = get_project_root_path()
print("Path:\t\t", ROOT)
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device:\t {device}")

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


In [4]:
lr = config.lr
batch_size = config.batch_size
num_epochs = config.num_epochs
dropout_rate = config.dropout_rate

label_encoding = config.label_encoding
label_list = config.label_decoding

compute_metrics = compute_metrics_fast

---
## Create Tokenizer and Load Model

In [5]:
model_architecture = T5EncoderModel
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(base_model_name, model_architecture)

---
## Load Adapter

In [6]:
adapter_location = '/models/linear_model_v3'
t5_lora_model_config = PeftConfig.from_pretrained(ROOT + adapter_location)

In [7]:
t5_base_model = PeftModel.from_pretrained(
    model=t5_base_model,
    model_id=ROOT+adapter_location,
    is_trainable=False,
    )
# del t5_base_model

In [8]:
t5_lora_model = inject_linear_layer(
    t5_lora_model=t5_base_model,
    num_labels=config.label_decoding.__len__(),
    dropout_rate=config.dropout_rate
    )

In [9]:
t5_lora_model.forward

<bound method injected_forward of PeftModel(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=1024, out_features=4096, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_emb

---
## Load Data

In [10]:
df_data = pd.read_parquet(ROOT + '/data/processed/5.0_train.parquet.gzip')

In [11]:
df_data.head()

Unnamed: 0,Sequence,Label,Split
0,M A P T L F Q K L F S K R T G L G A P G R D A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0
1,M D F T S L E T T T F E E V V I A L G S N V G ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1
2,M D D I S G R Q T L P R I N R L L E H V G N P ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1
3,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4
4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4


In [12]:
# ToDo: Use entire test set
ds_test = df_data[df_data.Split.isin([4])].head(config.dataset_size)
ds_test = df_to_dataset(
    t5_tokenizer,
    ds_test.Sequence.to_list(),
    ds_test.Label.to_list()
)

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [13]:
ds_test

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 3
})

In [14]:
print(*ds_test['input_ids'][0])
print(*ds_test['attention_mask'][0])
print(*ds_test['labels'][0])

19 4 5 11 6 14 19 9 5 20 9 11 7 10 21 17 7 18 18 3 10 11 16 9 3 18 7 7 6 13 6 7 17 19 17 7 5 4 5 7 19 17 7 19 17 11 18 19 11 19 17 11 19 11 11 7 5 17 19 11 13 3 7 15 17 19 7 18 3 17 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0


In [15]:
input_str = t5_tokenizer.decode(ds_test['input_ids'][0][:-1])
print(input_str)

M L G T V K M E G H E T S D W N S Y Y A D T Q E A Y S S V P V S N M N S G L G S M N S M N T Y M T M N T M T T S G N M T P A S F N M S Y A N


In [16]:
inputs = t5_tokenizer(input_str)
print(inputs)

{'input_ids': [19, 4, 5, 11, 6, 14, 19, 9, 5, 20, 9, 11, 7, 10, 21, 17, 7, 18, 18, 3, 10, 11, 16, 9, 3, 18, 7, 7, 6, 13, 6, 7, 17, 19, 17, 7, 5, 4, 5, 7, 19, 17, 7, 19, 17, 11, 18, 19, 11, 19, 17, 11, 19, 11, 11, 7, 5, 17, 19, 11, 13, 3, 7, 15, 17, 19, 7, 18, 3, 17, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


---
## Load Data

In [17]:
inid = torch.tensor(ds_test['input_ids']).to(device)
print(inid.shape)

torch.Size([3, 71])


In [18]:
t5_lora_model.to(device)

PeftModel(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=1024, out_features=4096, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): ParameterDict()
       

In [42]:
# with torch.no_grad():
results = []
for index, _ in enumerate(inid):
    if index == 10:
        break
    if index % 100 == 0:
        torch.cuda.empty_cache()
    results += t5_lora_model(input_ids=inid[index:index+1]).logits#.argmax(dim=-1).tolist()

abc
True

<class 'peft.peft_model.PeftModel'>
PeftModel(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=1024, out_features=4096, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                 

In [43]:
len(results)

3

In [44]:
ds_test['labels'][0]

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [45]:
results[0]

tensor([[-1.1002e-01, -1.3651e-01, -2.6129e-01, -1.2518e-02, -1.1174e-01,
          6.1251e-02],
        [ 1.1470e-01,  4.1029e-02, -2.5060e-01,  7.3243e-02, -2.8825e-02,
          2.0297e-01],
        [-7.5371e-03, -1.1698e-01, -1.8258e-01, -3.7678e-02, -4.3240e-02,
          1.2707e-01],
        [-9.1174e-02,  1.8264e-02, -1.9967e-01,  6.4711e-02,  2.7490e-02,
          4.7841e-02],
        [-1.0222e-01, -1.0567e-01, -1.1503e-01,  4.3114e-03, -1.7900e-02,
          3.5852e-03],
        [ 2.5795e-01, -4.5228e-01,  9.7303e-03,  1.5691e-01,  2.3870e-02,
          1.2575e-01],
        [-5.2036e-04, -5.6591e-02,  3.6631e-02,  1.3374e-01, -1.2680e-02,
          1.0854e-01],
        [-5.4624e-02, -2.2156e-02, -7.2359e-02, -3.2889e-02, -1.3086e-01,
          2.9044e-01],
        [ 6.1361e-03, -3.0071e-01, -7.0684e-02, -1.1380e-01, -1.8691e-01,
          9.5446e-02],
        [-1.1758e-01,  4.7077e-02, -2.4958e-01,  9.2215e-02, -1.1208e-01,
          1.3471e-01],
        [ 8.6320e-04,  1.4906e

In [23]:
inid[0]

tensor([19,  4,  5, 11,  6, 14, 19,  9,  5, 20,  9, 11,  7, 10, 21, 17,  7, 18,
        18,  3, 10, 11, 16,  9,  3, 18,  7,  7,  6, 13,  6,  7, 17, 19, 17,  7,
         5,  4,  5,  7, 19, 17,  7, 19, 17, 11, 18, 19, 11, 19, 17, 11, 19, 11,
        11,  7,  5, 17, 19, 11, 13,  3,  7, 15, 17, 19,  7, 18,  3, 17,  1],
       device='mps:0')

In [25]:
results[0]

tensor([[-0.1071, -0.1378, -0.1933,  0.1902, -0.0305,  0.1059],
        [ 0.1011,  0.0519, -0.1658,  0.0611,  0.0029,  0.1912],
        [-0.1178, -0.2082, -0.1434,  0.0763, -0.0423,  0.1246],
        [-0.1257, -0.0359, -0.1733, -0.0491, -0.0415, -0.0466],
        [ 0.0497, -0.2000, -0.2202, -0.0820,  0.0453,  0.1234],
        [ 0.2108, -0.3067,  0.0053,  0.1441, -0.1169,  0.0552],
        [-0.1084, -0.1360,  0.0363,  0.0255,  0.0093,  0.1930],
        [-0.0580, -0.0603, -0.1644,  0.0522, -0.2061,  0.2905],
        [-0.1383, -0.4175, -0.1037,  0.1006, -0.0871,  0.0957],
        [-0.1784, -0.0133, -0.2317,  0.0383, -0.1960,  0.1178],
        [-0.0051,  0.2334,  0.0213, -0.0300, -0.0167,  0.1881],
        [-0.0714, -0.1151, -0.2211,  0.1327,  0.1008,  0.1833],
        [-0.1677,  0.0007, -0.0345, -0.0171, -0.1235,  0.2255],
        [ 0.0793, -0.0709, -0.1627, -0.2155, -0.1119,  0.2311],
        [-0.2228, -0.1180, -0.0904,  0.0278, -0.0210,  0.3222],
        [-0.0925, -0.0783, -0.1517,  0.0

In [65]:
ground_truth = [[config.label_decoding[y] for y in x] for x in ds_test['labels']]

In [66]:
print(len(ground_truth))

4146


In [67]:
correct = 0
incorrect = 0

for index, item in enumerate(results):
    truth = ground_truth[index]
    prediction = [config.label_decoding[x] for x in item[:len(ground_truth[index])]]
    
    # if index % 50 == 0:
    print('T: ', *truth, sep='')
    print('P: ', *prediction, sep='')
    print()
    
    for t, p in zip(truth, prediction):
        if t == p:
            correct += 1
        else:
            incorrect += 1
    

    
print("Correct", correct)
print("Incorrect", incorrect)

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: TLSLLTSLSLLOOLLSLLLLLLLLLLLLLLLLLIMLLLOLISLLSILIIILIISOOOLILLOLLSLLLLL

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: LLLLLTOLSSLOLOLLLLLLLSSLOLLLLTISLOLLSLSISSSISOSIILOSLLOOOSSLISOSISLLLL

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: OLLLSLOLOLOLLOLLLLLLLLTLLLISLSLLLLLLSSLLLOLLOLOOOLILMOOIIIILILLILLLOLL

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: SSSSLLLLLLLSSLLMLLLMSSMLLLSLOSLLSTLOLLSSOSLLLLSSLMOMOSOOOOMMOLOOLOLLOL

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: LMLSSLLLLLLLLLLLLLSMTILSLILLOOLOLSSLLSTLLOSLTOLTOOOSSOOOOLTSOLMOLLLLLL

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: LLLLLLSOLSLLOLLLMLLLOLLSLLLLLLTSOILLLLLLLLLLLLLSLLLIOLLLLLILLLLLLLLLLL

T: IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
P: MSSSMSSLSSLLSSSSSSSSLSSLSSSIS

In [61]:
print(correct/(correct+incorrect))

0.09


In [63]:
gc.collect()

797

---
## Measure Performance

---
---
---
## Save Trainer

Q02742|EUKARYA|NO_SP|4\
MLRTLLRRRLFSYPTKYYFMVLVLSLITFSVLRIHQKPEFVSVRHLELAGENPSSDINCTKVLQGDVNEI\
IIIIIIIIIMMMMMMMMMMMMMMMMMMMMMMMOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO

Q9NQR9|EUKARYA|NO_SP|4\
MDFLHRNGVLIIQHLQKDYRAYYTFLNFMSNVGDPRNIFFIYFPLCFQFNQTVGTKMIWVAVIGDWLNLI\
OOOOOOOOOOOOOOOOOOOOOOOOMMMMMMMMMMMMMMMMMMMMMIIIIIIIIIIIMMMMMMMMMMMMMM

In [None]:
t5_lora_model

In [None]:
device = 'cpu'
t5_lora_model.to(device)

test_set = ds_test.with_format("torch", device=device)

# For token classification we need a data collator here to pad correctly
data_collator = DataCollatorForTokenClassification(t5_tokenizer) 

# Create a dataloader for the test dataset
test_dataloader = DataLoader(test_set, batch_size=16, shuffle = False, collate_fn = data_collator)

# Put the model in evaluation mode
t5_lora_model.eval()

# Make predictions on the test dataset
predictions = []
# We need to collect the batch["labels"] as well, this allows us to filter out all positions with a -100 afterwards
padded_labels = []

counter = 0

with torch.no_grad():
    for batch in test_dataloader:
        print(counter)
        counter += 1
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Padded labels from the data collator
        padded_labels += batch['labels'].tolist()
        # Add batch results(logits) to predictions, we take the argmax here to get the predicted class
        prediction = t5_lora_model(input_ids=input_ids).logits.argmax(dim=-1).tolist()
        print(prediction)
        predictions += prediction#.argmax(dim=-1).tolist()

In [None]:
index_item = 0

actual = [config.label_decoding[x] for x in test_set['labels'][index_item].tolist()]
print(actual.__len__())
print(*actual)
pred = [config.label_decoding[x] for x in predictions[index_item]]
print(pred.__len__())
print(*pred)