## Notebook Setup: Packages and Variables
___

In [1]:
# Reload edited project modules (e.g. src.model, src.config) automatically before
# executing code, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import re
import gc
import os
import math
import copy
import types
import yaml
import sys

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)
from torch.utils.data import DataLoader

import evaluate

import transformers
from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    EvalPrediction,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

from datasets import Dataset, DatasetDict

import src.config as config

from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_fast
    )
from src.utils import get_project_root_path

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [3]:
base_model_name = config.base_model_name
# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
ROOT = get_project_root_path()

# Seed Python/NumPy/torch before any weights are created. The LoRA adapters
# (get_peft_model) and the injected classification head are initialised before the
# Trainer is constructed, and the Trainer's own seeding (TrainingArguments.seed)
# only happens at construction time, i.e. after those weights already exist.
SEED = 42
set_seed(SEED)

if config.VERBOSE:
    print("Base Model:\t", base_model_name)
    print("MPS:\t\t", torch.backends.mps.is_available())
    print("Path:\t\t", ROOT)
    print(f"Using device:\t {device}")

pd.set_option('display.max_colwidth',3000)
pd.set_option('display.max_columns', 1000)
pd.set_option('display.max_rows', 1000)

---
## Create Tokenizer and Load Model

In [4]:
# Only the encoder stack is loaded (T5EncoderModel), not the full encoder-decoder.
model_architecture = T5EncoderModel
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(
    base_model_name,
    model_architecture,
)

In [5]:
# sequence_example = 'MSLSRREFVKLCSAGVAGLGISQIY'
# print(len(sequence_example))
# sequence_example = " ".join(list(re.sub(r"[UZOB]", "X", sequence_example)))

# print(t5_tokenizer.decode([4, 7, 7, 10, 8, 1]))
# print(t5_tokenizer.decode([4, 7, 7, 10, 8]))

# t_tensor = torch.tensor([4, 7, 7, 10, 8, 1]).to('mps')
# t_tensor = t_tensor.unsqueeze(0)
# with torch.no_grad():
#     res = t5_base_model(input_ids=t_tensor)
# print(res)
# print(res.last_hidden_state.argmax(-1))

# t_tensor = torch.tensor([4, 7, 7, 10, 8]).to('mps')
# t_tensor = t_tensor.unsqueeze(0)
# with torch.no_grad():
#     res = t5_base_model(input_ids=t_tensor)
# print(res)
# print(res.last_hidden_state.argmax(dim=-1))

# tkns = t5_tokenizer.batch_encode_plus([sequence_example, sequence_example[::-1]], padding=True, return_tensors="pt")
# tkns.to(device)
# print(tkns.input_ids)

# print(tkns.input_ids.shape)
# print(tkns.input_ids[0])
# print(tkns.attention_mask[0][:-1].shape)

# t_tensor_input = torch.tensor([[4, 7, 7, 12, 1, 1]]).to('mps')
# t_tensor_mask = torch.tensor([[1, 1, 1, 1, 1, 0]]).to('mps')

# with torch.no_grad():
#     res = t5_base_model(
#         input_ids=t_tensor_input,
#         attention_mask=t_tensor_mask,
#         )
# print(res.last_hidden_state.shape)
# print(res.last_hidden_state.argmax(dim=-1))

---
## Load Data, Split into Datasets, and Tokenize Sequences

In [6]:
df_data = pd.read_parquet(ROOT + '/data/processed/5.0_train.parquet.gzip')
if config.VERBOSE:
    # An expression nested inside `if` is not the cell's last top-level expression,
    # so Jupyter never renders it; display the preview explicitly.
    display(df_data.head())

In [7]:
def _split_to_dataset(tokenizer, df, splits, max_rows=None):
    """Tokenize the rows of ``df`` whose ``Split`` fold id is in ``splits``.

    Args:
        tokenizer: tokenizer forwarded to ``df_to_dataset``.
        df: frame with ``Split``, ``Sequence`` and ``Label`` columns.
        splits: fold ids to keep.
        max_rows: keep only the first ``max_rows`` matching rows; ``None`` keeps all.

    Returns:
        Whatever ``df_to_dataset`` returns for the selected rows.
    """
    subset = df[df.Split.isin(splits)]
    if max_rows is not None:
        subset = subset.head(max_rows)
    return df_to_dataset(
        tokenizer,
        subset.Sequence.to_list(),
        subset.Label.to_list(),
    )


# A falsy config.dataset_size (None or 0) means "use every row" for ALL splits;
# previously 0 kept every train row but produced empty valid/test sets.
# Train is built from three folds, so (presumably for that reason) it gets 3x the cap.
max_rows_per_split = config.dataset_size or None
ds_train = _split_to_dataset(
    t5_tokenizer, df_data, [0, 1, 2],
    max_rows_per_split * 3 if max_rows_per_split else None,
)
ds_validate = _split_to_dataset(t5_tokenizer, df_data, [3], max_rows_per_split)
ds_test = _split_to_dataset(t5_tokenizer, df_data, [4], max_rows_per_split)

dataset_signalp = DatasetDict({
    'train': ds_train,
    'valid': ds_validate,
    'test': ds_test,
})

# Free the raw frame. Re-running this cell requires re-running the load cell first.
del df_data

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [8]:
if config.VERBOSE:
    # Summary of each split, then the raw encoding of the first test example.
    for split_ds in (ds_train, ds_validate, ds_test):
        print(split_ds)
    print('----------------------------------')
    first_test = ds_test[0]
    for key in ('input_ids', 'attention_mask', 'labels'):
        print(first_test[key])
        print(len(first_test[key]))
    print('----------------------------------')
    # Decode the first test sequence, then the tokens for ids 0..27.
    print(t5_tokenizer.decode(first_test['input_ids']))
    print(t5_tokenizer.decode(range(0, 28)))

In [9]:
# sequence_examples = ["PRTEINO", "SEQWENCE"]
# sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]
# ids = t5_tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest")
# input_ids = torch.tensor(ids['input_ids']).to(device)
# attention_mask = torch.tensor(ids['attention_mask']).to(device)
# print(input_ids)
# print(attention_mask)
# with torch.no_grad():
#     embedding_repr = t5_base_model(input_ids=input_ids,attention_mask=attention_mask)

# emb_0 = embedding_repr.last_hidden_state[0,:7]
# print(f"Shape of per-residue embedding of first sequences: {emb_0.shape}")

---
## Apply LoRA

In [10]:
# LoRA adapters on the attention output projection ('o') only; use
# ['q', 'k', 'v', 'o'] to adapt every attention projection instead.
# prepare_model_for_kbit_training (imported above) is not applied, so no quantization.
lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=['o'],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    inference_mode=False,
)
# PEFT injects the adapter layers into t5_base_model's own modules (in place), so
# t5_base_model no longer refers to an untouched base model after this call.
t5_lora_model = get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

trainable params: 983,040 || all params: 1,209,124,864 || trainable%: 0.0813017769519625


In [11]:
if config.VERBOSE:
    # Module tree, a blank line, then the bound forward method.
    print(t5_lora_model, '', t5_lora_model.forward, sep='\n')

In [12]:
# t_tensor_input = torch.tensor([[4, 7, 7, 12, 10, 8]]).to('mps')
# t_tensor_mask = torch.tensor([[1, 1, 1, 1, 0, 0]]).to('mps')

# t_tensor_input = torch.tensor([[4, 7, 7, 12, 10, 8]]).to('mps')
# t_tensor_mask = torch.tensor([[1, 1, 1, 1, 0, 0]]).to('mps')

# with torch.no_grad():
#     res = t5_lora_model(
#         input_ids=t_tensor,
#         attention_mask=t_tensor_mask,
#         )
# print(res)
# print(res.last_hidden_state.argmax(dim=-1))

---
## Training Loop
Reference: [PEFT LoRA token-classification guide](https://huggingface.co/docs/peft/task_guides/token-classification-lora)

In [13]:
# Pads each batch to a common length; labels are padded with -100 (the collator's
# default, and torch CrossEntropyLoss's default ignore_index).
# NOTE(review): confirm the injected head's loss ignores -100 label positions.
data_collator = DataCollatorForTokenClassification(
    tokenizer=t5_tokenizer,
    label_pad_token_id=-100,
)

In [14]:
# Attach the linear head built by inject_linear_layer, sized to the label set.
# NOTE(review): the cell rebinds t5_lora_model to the helper's result, so re-running
# it feeds the already-extended model back in — TODO confirm the helper tolerates that.
t5_lora_model = inject_linear_layer(
    t5_lora_model=t5_lora_model,
    dropout_rate=config.dropout_rate,
    num_labels=len(config.label_decoding),
)

In [15]:
# t5_lora_model.forward

In [16]:
# metric = evaluate.load("glue", "mrpc")
# def compute_metrics_custom(eval_preds: EvalPrediction):
#     print(*eval_preds)
#     logits, labels = eval_preds
#     predictions = np.argmax(logits, axis=-1)
#     return metric.compute(predictions=predictions, references=labels)

In [17]:
# Hyperparameters come from src.config. Evaluation and checkpoint-selection options
# (evaluation_strategy, eval_steps, load_best_model_at_end, save_total_limit, fp16,
# deepspeed) are not set here, so TrainingArguments defaults apply.
training_args = TrainingArguments(
    output_dir='./checkpoints',
    seed=42,
    learning_rate=config.lr,
    num_train_epochs=config.num_epochs,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    logging_steps=config.logging_steps,
    save_steps=config.save_steps,
)

# Training only: no eval_dataset / compute_metrics yet. To enable evaluation, pass
# eval_dataset=dataset_signalp['valid'] together with an evaluation strategy above.
trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset_signalp['train'],
)

In [18]:
# ([set(x) for x in ds_validate['labels']])

In [19]:
# ds_validate['labels'][
#     [set(x) for x in ds_validate['labels']] != {0}
#     ]

In [20]:
# print(next(t5_lora_model.parameters()).is_cuda)
# Report the device t5_lora_model says it lives on.
current_device = t5_lora_model.device
print(current_device)
# print(config.label_decoding)

mps:0


In [21]:
# Free unreachable Python objects and release cached accelerator memory.
# Each backend is guarded: torch.mps.empty_cache() raises a RuntimeError on machines
# without an MPS backend (e.g. CUDA-only or CPU-only hosts).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [27]:
# Show the model configuration (a T5Config for this model).
model_config = t5_lora_model.config
print(model_config)

T5Config {
  "_name_or_path": "Rostlab/prot_t5_xl_uniref50",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 16384,
  "d_kv": 128,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 32,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.35.0",
  "use_cache": true,
  "vocab_size": 128
}



In [22]:
# Run training with the Trainer configured above.
# NOTE(review): the last recorded run of this cell failed with
# "RuntimeError: Boolean value of Tensor with more than one value is ambiguous",
# after debug output ('abc', 'True', the model repr) that appears to come from inside
# the model's forward — TODO investigate the injected forward.
trainer.train()

  0%|          | 0/1 [00:00<?, ?it/s]

abc
True

<class 'peft.peft_model.PeftModelForFeatureExtraction'>
PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(
                    in_features=4096, out_features=1024, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=40

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [None]:
# metrics=trainer.evaluate()
# print(metrics)

---

In [None]:
# ds_validate['labels'][9].__len__()

In [None]:
# print(ds_validate)

In [None]:
# target = torch.zeros(4148,70)
# target.shape

In [None]:
# [x + [-1] * (70-len(x)) for x in ds_validate['labels']]

In [None]:
# inlab = torch.tensor([x + [-1] * (70-len(x)) for x in ds_validate['labels']]).to('cpu')
# inlab.shape
# del inlab

In [None]:
# inid = torch.tensor(ds_validate['input_ids']).to(device)
# inid.shape

In [None]:
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# predictions = t5_lora_model(input_ids=inid[0:1])

In [None]:
# results = []
# for index, _ in enumerate(inid):
#     if index % 100 == 0:
#         torch.cuda.empty_cache()
#     results += t5_lora_model(input_ids=inid[index:index+1]).logits.argmax(dim=-1).tolist()

In [None]:
# len(results)

In [None]:
# correct_labels = [[config.label_decoding[y] for y in x] for x in ds_validate['labels']]

In [None]:
# len(correct_labels)

In [None]:
# correct = 0
# incorrect = 0

# for index, item in enumerate(results):
#     truth = correct_labels[index]
#     prediction = [config.label_decoding[x] for x in item[:len(correct_labels[index])]]
    
#     if index % 50 == 0:
#         print(*truth, sep='')
#         print(*prediction, sep='')
#         print()
    
#     for t, p in zip(truth, prediction):
#         if t == p:
#             correct += 1
#         else:
#             incorrect += 1
    

    
# print("Correct", correct)
# print("Incorrect", incorrect)

In [None]:
# print(correct/(correct+incorrect))

In [None]:
# print(*prediction[0].argmax(dim=-1)[0].tolist())

In [None]:
# print(*[config.label_decoding[x] for x in prediction[0].argmax(dim=-1)[0].tolist()])

In [None]:
# print(prediction[0].argmax(dim=-1))
# print(inlab)
# print(inlab == prediction[0].argmax(dim=-1)[:, :70])

In [None]:
# cust_pred = EvalPrediction(prediction, inlab)

In [None]:
# print(*cust_pred)

In [None]:
# compute_metrics_custom(cust_pred)

---
## Results

In [None]:
# One row per entry of the Trainer's log history; keys missing from an entry become NaN.
log_rows = trainer.state.log_history
result_log = pd.DataFrame(data=log_rows)

In [None]:
# Bare last expression: Jupyter renders the frame with its rich HTML repr.
result_log

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Training-loss curve from the Trainer log. Log entries without a 'loss' key are
# NaN in that column.
sns.set_theme(style="darkgrid")
fig, ax = plt.subplots()
my_plot = sns.lineplot(data=result_log, x='epoch', y='loss', ax=ax)
# Title and axis labels so the figure stands on its own; ';' hides the repr.
my_plot.set(title='Training loss', xlabel='Epoch', ylabel='Loss');

In [None]:
# Persist the loss curve. Create ./plots first so savefig does not fail with
# FileNotFoundError when the directory does not exist yet.
os.makedirs('./plots', exist_ok=True)
fig = my_plot.get_figure()
fig.savefig("./plots/out.png")

---
## Save Model

In [None]:
# Save the PEFT-wrapped model under <project root>/models/<model_name>.
# NOTE(review): PeftModel.save_pretrained stores adapter weights; confirm the
# injected linear head is persisted as well.
save_dir = f'{ROOT}/models/{config.model_name}'
t5_lora_model.save_pretrained(save_dir)

---
## Reload Model and Compare Weights

In [None]:
# Label decoding from the project config (presumably index -> label); its length
# sizes the classification head injected when reloading below.
label_list = config.label_decoding

In [None]:
adapter_location = '/models/linear_model_v3'
adapter_path = ROOT + adapter_location
t5_lora_model_config = PeftConfig.from_pretrained(adapter_path)

# get_peft_model() earlier injected LoRA layers into t5_base_model in place, so
# wrapping that same object again would stack a second adapter on top of the first.
# Reload a clean base model (the one recorded in the adapter config, falling back to
# the configured base model name) before attaching the saved adapter.
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(
    t5_lora_model_config.base_model_name_or_path or base_model_name,
    T5EncoderModel,
)
t5_base_model = PeftModel.from_pretrained(
    model=t5_base_model,
    model_id=adapter_path,
)
# NOTE(review): if the injected linear head was not saved with the adapter, it is
# freshly initialised here rather than restored — TODO confirm.
t5_lora_model = inject_linear_layer(t5_base_model, dropout_rate=config.dropout_rate, num_labels=len(label_list))