## Notebook Setup
___

In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# the local `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Packages
___

In [2]:
import re
import os
import math
import copy
import types
import yaml

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)

import evaluate

from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

from datasets import Dataset

import src.config as config

from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_full,
    compute_metrics_fast
    )
from src.utils import get_project_root_path

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


---
## Setup and Variables

In [3]:
# Identifier of the base model handed to the tokenizer/model loader below.
base_model_name = 'Rostlab/prot_t5_xl_uniref50'
print("Base Model:\t", base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())

# Project root; data and config paths in later cells are built from it.
ROOT = get_project_root_path()
print("Path:\t\t", ROOT)

# Device preference order: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device:\t {device}")

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


---
## Create Tokenizer and Load Model

In [4]:
# Architecture class handed to the loader. The commented-out line keeps the
# alternative architecture at hand.
# model_architecture = T5ForConditionalGeneration
model_architecture = T5EncoderModel

# The loader (from src.model) returns the tokenizer and the base model as a pair.
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(base_model_name, model_architecture)

---
## Load Data, Split into Dataset, and Tokenize Sequences

In [5]:
# Processed table holding sequences, labels and a `Split` column that later
# cells use to separate train from test rows.
train_parquet_path = ROOT + '/data/processed/5.0_train.parquet.gzip'
df_data = pd.read_parquet(train_parquet_path)

In [6]:
# Preview the first rows to sanity-check the loaded columns.
df_data.head(5)

Unnamed: 0,Sequence,Label,Split
0,M A P T L F Q K L F S K R T G L G A P G R D A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train
1,M D F T S L E T T T F E E V V I A L G S N V G ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train
2,M D D I S G R Q T L P R I N R L L E H V G N P ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train
3,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",test
4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",test


In [7]:
# Number of training sequences to tokenize. Kept tiny for a quick smoke run;
# set to None to use the whole training split.
N_TRAIN_SAMPLES = 10

# Keep the filtered DataFrame under its own name so that `ds_train` only ever
# refers to the dataset built by `df_to_dataset`.
df_train = df_data[df_data.Split == 'train']

ds_train = df_to_dataset(
    t5_tokenizer,
    df_train.Sequence.to_list()[:N_TRAIN_SAMPLES],
    df_train.Label.to_list()[:N_TRAIN_SAMPLES],
)

In [8]:
# ToDo: Use entire test set
# Number of test sequences to tokenize. Kept tiny for a quick smoke run;
# set to None to use the whole test split.
N_TEST_SAMPLES = 5

# Keep the filtered DataFrame under its own name so that `ds_test` only ever
# refers to the dataset built by `df_to_dataset`.
df_test = df_data[df_data.Split == 'test']

ds_test = df_to_dataset(
    t5_tokenizer,
    df_test.Sequence.to_list()[:N_TEST_SAMPLES],
    df_test.Label.to_list()[:N_TEST_SAMPLES],
)

---
## Apply LoRA

In [9]:
# LoRA adapter settings used to wrap the base model in the next cell.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Only modules named 'o' are adapted for now.
    # target_modules=['q', 'k', 'v', 'o'],
    target_modules=['o'],
    inference_mode=False,
    # task_type=TaskType.TOKEN_CLS,
)

In [10]:
# Wrap the base model with the LoRA adapters described by `lora_config`.
t5_lora_model = get_peft_model(t5_base_model, lora_config)
# t5_lora_model = prepare_model_for_kbit_training(t5_lora_model)

In [11]:
# Report how many parameters are trainable after applying LoRA.
t5_lora_model.print_trainable_parameters()

trainable params: 983,040 || all params: 1,209,124,864 || trainable%: 0.0813017769519625


---
## Model

In [12]:
# `inject_linear_layer` comes from src.model; presumably it attaches the
# per-token classification head to the model — TODO confirm against its
# implementation.
t5_lora_model = inject_linear_layer(t5_lora_model)

---
## DeepSpeed

In [13]:
# Environment variables describing a single-process "distributed" run
# (one rank, one local rank, world size of one) on the local machine.
os.environ.update({
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "9994",  # modify if RuntimeError: Address already in use
    "RANK": "0",
    "LOCAL_RANK": "0",
    "WORLD_SIZE": "1",
})

---
## Training Loop
Reference: Hugging Face PEFT task guide on LoRA for token classification — https://huggingface.co/docs/peft/task_guides/token-classification-lora

In [14]:
# Label mappings defined in `src.config`.
label_encoding = config.label_encoding
label_list = config.label_decoding

# Metric function handed to the Trainer; `compute_metrics_full` is the other
# implementation imported from src.model.
compute_metrics = compute_metrics_fast

In [15]:
# Training hyperparameters are read from `src.config` so they live in one place.
lr = config.lr
batch_size = config.batch_size
num_epochs = config.num_epochs
dropout_rate = config.dropout_rate

In [16]:
# Collator handed to the Trainer to assemble batches; built from the tokenizer
# loaded earlier.
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

In [17]:
# Parse the DeepSpeed settings stored at the project root into a plain Python
# object.
deepspeed_config_path = ROOT+'/deepspeed_config.yaml'
with open(deepspeed_config_path, 'r') as config_file:
    deepspeed_config = yaml.safe_load(config_file)

In [18]:
# Trainer configuration. Learning rate, batch size and epoch count all come
# from the values read from `src.config`, so changing the config changes the run.
training_args = TrainingArguments(
    output_dir='./',
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Use the configured epoch count instead of a hard-coded literal.
    num_train_epochs=num_epochs,
    # Checkpointing is disabled. `load_best_model_at_end` needs saved
    # checkpoints to restore from, so it is switched off as well; re-enable
    # the options below together (plus a matching evaluation strategy) when
    # checkpoints are wanted.
    save_strategy='no',
    load_best_model_at_end=False,
    # save_strategy="steps",
    # save_steps=100,
    # save_total_limit=3,
    seed=42,
    # deepspeed=deepspeed_config
)

In [19]:
# Assemble the Trainer from the LoRA model, the training arguments, the two
# datasets, the padding collator and the selected metric function.
trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_test, # TODO: switch to a dedicated validation split; the test subset is only a stand-in
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [20]:
# Sanity check: length of the first tokenized training example once it has
# been moved to the selected device.
torch.tensor(ds_train[0]['input_ids']).to(device).shape

torch.Size([71])

In [21]:
# Move the model to the CPU.
# NOTE(review): the Trainer was already constructed with this same model
# object — confirm which device training is expected to run on.
t5_lora_model = t5_lora_model.to('cpu')

In [22]:
# One output unit per entry in the label list.
num_labels = len(label_list)

# Attributes set on the PEFT wrapper itself.
t5_lora_model.dropout = nn.Dropout(dropout_rate)
t5_lora_model.num_labels = num_labels

# Attributes set on the wrapped base model: a fresh dropout layer and a
# freshly initialised linear classifier over the encoder's hidden size.
lora_base_model = t5_lora_model.get_base_model()
lora_base_model.dropout = nn.Dropout(dropout_rate)
lora_base_model.classifier = nn.Linear(
    in_features=lora_base_model.config.hidden_size,
    out_features=num_labels,
)

In [23]:
# Run fine-tuning; the call returns a TrainOutput summary that is displayed below.
trainer.train()

  0%|          | 0/10 [00:00<?, ?it/s]

{'train_runtime': 23791.994, 'train_samples_per_second': 0.004, 'train_steps_per_second': 0.0, 'train_loss': 1.8767080307006836, 'epoch': 10.0}


TrainOutput(global_step=10, training_loss=1.8767080307006836, metrics={'train_runtime': 23791.994, 'train_samples_per_second': 0.004, 'train_steps_per_second': 0.0, 'train_loss': 1.8767080307006836, 'epoch': 10.0})

---
## Save Model

In [None]:
# TODO: persist the fine-tuned model here; this cell currently saves nothing.
# t5_lora_model()

---
## Make Inference

In [25]:
from torch.utils.data import DataLoader

In [26]:
# Reuse the device chosen in the setup cell instead of hard-coding 'mps', so
# inference also runs on CUDA or CPU-only machines. The string form is used
# because it is accepted both by `.to(...)` and by the dataset formatter.
inference_device = str(device)
t5_lora_model.to(inference_device)

# Have the dataset hand out torch tensors that already live on that device.
test_set = ds_test.with_format("torch", device=inference_device)

# For token classification we need a data collator here to pad correctly
data_collator = DataCollatorForTokenClassification(t5_tokenizer)

# Create a dataloader for the test dataset; order is kept so predictions line
# up with the dataset rows.
test_dataloader = DataLoader(test_set, batch_size=16, shuffle=False, collate_fn=data_collator)

# Put the model in evaluation mode
t5_lora_model.eval()

# Predicted class ids, one list per sequence.
predictions = []
# We need to collect the batch["labels"] as well, this allows us to filter out
# all positions with a -100 afterwards
padded_labels = []

with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(inference_device)
        attention_mask = batch['attention_mask'].to(inference_device)
        # Padded labels from the data collator
        padded_labels += batch['labels'].tolist()
        # The argmax over the last dimension of the logits is the predicted
        # class for each token position.
        logits = t5_lora_model(input_ids=input_ids, attention_mask=attention_mask).logits
        predictions += logits.argmax(dim=-1).tolist()

In [42]:
# Display the test dataset's features and row count.
test_set

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 5
})

In [47]:
# Decode the label ids of the first test example into their label symbols
# and print them space-separated.
print(*(config.label_decoding[label_id] for label_id in test_set[0]['labels'].tolist()))

I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I


In [33]:
# Index of the test example whose predictions are printed.
sample_idx = 4

# Decode only the selected example rather than decoding every prediction and
# then discarding all but one.
print(*[config.label_decoding[label_id] for label_id in predictions[sample_idx]])

S I I I I I I I I I I I I I I I I I M I I I I T I M I I I I I I I I I I I I I O I I I M I I I B I I I I I I O I I I O I I I I I I M B I I S M


---
## Measure Performance

In [None]:
# base_model_test = T5ForConditionalGeneration.from_pretrained(
#     base_model_name,
#     device_map='auto',
#     offload_folder='./offload',
#     load_in_8bit=False
# )
# tsss_ids = t5_tokenizer('M A P T L F Q K L F S K R T G L G A P G R D A', return_tensors="pt").input_ids.to(device)
# tsss_mask = t5_tokenizer('M A P T L F Q K L F S K R T G L G A P G R D A', return_tensors="pt").attention_mask.to(device)
# base_model_test(input_ids=tsss_ids, decoder_input_ids=tsss_ids, attention_mask=tsss_mask)