In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. src.config) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    T5Config,
    modeling_outputs,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

import src.config as config

from peft import (
    LoraConfig,
    TaskType
)

import torch
import torch.nn as nn

import peft

import pandas as pd

from datasets import Dataset, DatasetDict

import gc

import time

In [None]:
# Choose the compute device in order of preference: first CUDA GPU, then
# Apple-silicon MPS, and finally plain CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Prefix used to build data/model paths relative to this notebook's directory.
ROOT = '../'

In [None]:
# Tokenizer for the checkpoint named in the project config. Casing of the
# input is left untouched (do_lower_case=False) and the non-legacy tokenizer
# behaviour is requested (legacy=False).
t5_tokenizer = T5Tokenizer.from_pretrained(
    config.base_model_name,
    legacy=False,
    use_fast=True,
    do_lower_case=False,
)

# Encoder-only T5 loaded from the same checkpoint. Device placement is
# delegated to the loader (device_map='auto'), 8-bit loading is disabled, and
# an offload folder under ROOT is supplied for weights that are offloaded.
t5_base_model = T5EncoderModel.from_pretrained(
    config.base_model_name,
    load_in_8bit=False,
    device_map='auto',
    offload_folder=ROOT + "/models/offload",
)

In [None]:
# LoRA adapter hyper-parameters: rank 8, alpha 16, 5% dropout, no bias terms
# trained, adapters attached to the modules named q/k/v/o (presumably the T5
# attention projections -- TODO confirm against the model's module names).
# NOTE(review): the task type is SEQ_CLS while the training cell further down
# uses DataCollatorForTokenClassification -- confirm this pairing is intended.
lora_config = peft.LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=['q', 'k', 'v', 'o'],
    task_type=peft.TaskType.SEQ_CLS,
    inference_mode=False,
)

# Wrap the base encoder with the adapters and report how many parameters
# remain trainable.
t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

In [None]:
# Read the processed training table from disk.
df_data = pd.read_parquet(f"{ROOT}/data/processed/5.0_train_full.parquet.gzip")

# Build the tokenized dataset splits from the table.
# NOTE(review): `new_model` is not imported in any visible cell; on a fresh
# kernel this line raises NameError unless the module is imported elsewhere.
dataset_signalp = new_model.create_datasets(
    data=df_data,
    tokenizer=t5_tokenizer,
    splits=config.splits,
    dataset_size=config.dataset_size,
)

# The raw frame is no longer needed once the datasets exist; drop the name so
# its memory can be reclaimed.
del df_data

In [None]:
# Smoke test: push one hand-written 5-token sequence through the LoRA-wrapped
# encoder. The tensors are created on the `device` selected earlier in the
# notebook instead of a hard-coded 'mps', so the cell also runs on CUDA and
# CPU-only machines. The module is called directly (rather than via
# `.forward`) so that any registered hooks run as well.
embds_2 = t5_lora_model(
    input_ids=torch.tensor([[7, 4, 7, 11, 7]], device=device),
    attention_mask=torch.tensor([[1, 1, 1, 1, 1]], device=device),
)

In [None]:
# Display the output object returned by the smoke-test forward pass.
embds_2

In [None]:
# Inspect the first example of the training split to check its fields.
dataset_signalp['train'][0]

In [None]:
# Show which class the LoRA-wrapped model is an instance of.
type(t5_lora_model)

In [None]:
# Collator that batches examples for token-level classification using the
# T5 tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

# Training hyper-parameters come from the project config; checkpoints are
# written to ./checkpoints. Checkpoint saving schedules, periodic evaluation,
# best-model reloading, fp16 and DeepSpeed were left disabled in this run.
training_args = TrainingArguments(
    output_dir='./checkpoints',
    seed=42,
    num_train_epochs=config.num_epochs,
    learning_rate=config.lr,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    logging_steps=config.logging_steps,
    # Keep every dataset column instead of letting the Trainer drop the ones
    # that are not in the model's forward signature.
    remove_unused_columns=False,
)

# Trainer over the LoRA-wrapped model with the train/valid splits.
# No compute_metrics callback is registered.
trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
)

In [None]:
gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()

In [None]:
# Run the fine-tuning loop with the Trainer configured in the earlier cell.
trainer.train()