In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
)
from trl import DPOTrainer, DPOConfig
import json
from datasets import Dataset
from sklearn.model_selection import train_test_split

In [None]:
# Pick the compute device once up front: CUDA when a GPU is visible, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the base GPT-2 large tokenizer and register the extra tokens the SFT model
# was trained with: a dedicated padding token and the "<bot>: " turn marker.
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
tokenizer.add_special_tokens(special_tokens_dict={"pad_token": "<pad>"})
tokenizer.add_tokens(new_tokens=["<bot>: "])

In [None]:
# Instantiate the GPT-2 large policy model on a single, explicitly chosen device.
# `device_map="auto"` is intentionally not used: it lets accelerate place (and on
# multi-GPU machines, split) the model on its own, which conflicts with the
# explicit `.to(device)` below and with `map_location=device` used when the SFT
# weights are loaded. `trust_remote_code` is not needed because GPT-2 is a
# built-in transformers architecture, so no repository code has to be executed.
policy_model = AutoModelForCausalLM.from_pretrained(
    "gpt2-large",
    low_cpu_mem_usage=True,
).to(device)

In [None]:
# Align the model with the extended tokenizer, then load the SFT weights.

# A "<pad>" token was added to the tokenizer, so the model config must use its id.
policy_model.config.pad_token_id = tokenizer.pad_token_id

# The KV cache only speeds up generation; turn it off while training.
policy_model.config.use_cache = False

# Grow the embedding matrix to cover the added tokens. This has to happen before
# loading the SFT checkpoint, which presumably was saved with the enlarged
# vocabulary (TODO confirm against the SFT notebook).
policy_model.resize_token_embeddings(len(tokenizer))

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code; a state_dict needs no more.
policy_model.load_state_dict(
    torch.load("weights/model_state_2_large_v2.pt", map_location=device, weights_only=True)
)

In [None]:
# LoRA adapter configuration. With this config passed to DPOTrainer, only the
# low-rank adapters are trained on top of the SFT weights, which keeps the number
# of trainable parameters (and the optimizer state) small.
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # GPT-2's attention (c_attn, c_proj) and MLP (c_fc, c_proj) projections.
    target_modules=["c_attn", "c_proj", "c_fc"],
    # GPT-2 implements these projections with transformers' Conv1D, which stores
    # weights transposed relative to nn.Linear. PEFT would flip this flag itself
    # (with a warning); setting it explicitly documents the requirement.
    fan_in_fan_out=True,
)

In [None]:
# Training arguments for DPO. `beta` (strength of the implicit KL penalty that keeps
# the policy close to the reference model) is configured here in DPOConfig; passing
# it to DPOTrainer directly is deprecated in TRL.
training_arguments = DPOConfig(
    output_dir="./results",
    beta=0.1,
    # Evaluate on the validation split every `eval_steps` optimizer steps.
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # switch when upgrading.
    evaluation_strategy="steps",
    do_eval=True,
    eval_steps=20,
    # Effective train batch per device = 2 sequences * 4 accumulation steps = 8.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=2,
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    warmup_steps=20,
    # max_steps > 0 takes precedence: training stops after 200 optimizer steps and
    # num_train_epochs is ignored (it only applies if max_steps is removed).
    num_train_epochs=10,
    max_steps=200,
    save_steps=10,
    logging_steps=10,
    log_level="debug",
    # Keep the raw prompt/chosen/rejected columns; DPOTrainer tokenizes them itself.
    remove_unused_columns=False,
)

In [None]:
# Load the preference data and hold out 20% of it for evaluation.
# The JSON must be column-oriented (a dict mapping column names to equal-length
# lists) because it is passed straight to Dataset.from_dict.
with open("data/dpo_dataset_RL.json", "r", encoding="utf-8") as fp:
    dict_dataset = json.load(fp)
dataset = Dataset.from_dict(dict_dataset)

# DPOTrainer reads these three columns from every example; fail early and clearly.
missing_columns = {"prompt", "chosen", "rejected"} - set(dataset.column_names)
assert not missing_columns, f"DPO dataset is missing columns: {sorted(missing_columns)}"

# Use the datasets-native split (rather than sklearn on a Dataset plus a dict
# round-trip) and fix the seed so the train/validation split is reproducible.
splits = dataset.train_test_split(test_size=0.2, seed=42)
train_set, val_set = splits["train"], splits["test"]
assert len(train_set) + len(val_set) == len(dataset)

In [None]:
# Build the DPO trainer and run training.
# With ref_model=None and a peft_config, TRL trains only the LoRA adapters. It uses
# the same model with the adapters disabled as the frozen reference, so no second
# copy of GPT-2 large is loaded. `beta` is taken from the DPOConfig in `args`.
trainer = DPOTrainer(
    policy_model,
    ref_model=None,
    args=training_arguments,
    peft_config=peft_config,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
)
trainer.train()

In [None]:
# Save the post-DPO weights as a plain GPT-2 state dict, in the same format as the
# SFT checkpoint loaded earlier with load_state_dict.
# Training with `peft_config` wraps the policy in LoRA adapter layers, so its raw
# state_dict would contain adapter-specific keys (lora_A / lora_B / base_layer)
# that a plain GPT-2 model cannot load. merge_and_unload folds the adapters into
# the base weights and strips the wrappers before saving.
trained_model = trainer.model
if hasattr(trained_model, "merge_and_unload"):
    trained_model = trained_model.merge_and_unload()
torch.save(trained_model.state_dict(), "weights/model_post_DPO.pt")