In [None]:
import torch
import sys
sys.path.append('..')
from model.utils import LMHyperParams, SmModel, ModelChoice
from dataset.squad import UltraFeedbackDataModule
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft.tuners.lora.config import LoraConfig
from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig
from typing import cast
from peft.peft_model import PeftModel
import gc
%load_ext autoreload
%autoreload 2

In [None]:
# Checkpoint to fine-tune with DPO.
model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"  # replace with your model id

# Quantize the frozen base weights to 4-bit NF4 with nested ("double") quantization;
# computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Loading options gathered in one place so they are easy to tweak.
model_load_kwargs = dict(
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_id)  # type: ignore
# Pad with the EOS token.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the left, as flash attention expects (avoids FA errors).
tokenizer.padding_side = "left"
# Truncate from the left so the end of each sequence (the generation) is never cut off.
tokenizer.truncation_side = "left"

In [None]:
# Build the UltraFeedback preference data; its train_dataset / val_dataset are handed
# to DPOTrainer in the config cell below.
# NOTE(review): the constructor arguments are positional and their meaning is not
# visible in this notebook (2 may be a batch size, 1024 a max token length) — TODO
# confirm against dataset/squad.py and consider passing them by keyword.
data_module = UltraFeedbackDataModule(2, tokenizer, 1024, 100, False)
# debugger will fail without this
# NOTE(review): the DPOConfig below uses dataloader_num_workers=0 for the same reason;
# 1 still implies a worker process here — confirm which value the debugger needs.
data_module.num_workers = 1
# Populate train_dataset / val_dataset for the "fit" stage.
data_module.setup("fit")

In [None]:
# Reclaim memory before building the trainer: first collect unreachable Python objects
# (releasing any CUDA tensors they still own), then return PyTorch's cached-but-unused
# CUDA blocks to the driver. Order matters, so iterate the steps explicitly.
for release_step in (gc.collect, torch.cuda.empty_cache):
    release_step()

In [None]:
# DPO setup: LoRA adapter config, DPO training arguments, and the trainer itself.
# max_prompt_length is the maximum length of the prompt and the max_length is the maximum length of the prompt + chosen or rejected response
prompt_length = 1024
max_seq_length = 1512

# LoRA adapters on every linear layer of the (4-bit, frozen) base model.
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

args = DPOConfig(
    output_dir="doplhin-dpo",
    num_train_epochs=1,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    learning_rate=5e-5,
    max_grad_norm=0.3,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=25,
    save_steps=500,
    save_total_limit=2,
    # `evaluation_strategy` is deprecated (transformers >= 4.41) and rejected by
    # transformers >= 4.46; `eval_strategy` is the supported name.
    eval_strategy="steps",
    eval_steps=700,
    bf16=True,
    tf32=True,
    push_to_hub=False,
    report_to="tensorboard",
    # debugger will fail without this
    dataloader_num_workers=0,
    dataset_num_proc=1,
    max_length=max_seq_length,
    max_prompt_length=prompt_length,
    # Compute the reference model's log-probs once up front instead of on every step.
    precompute_ref_log_probs=True,
    dataloader_pin_memory=True,
    # DPO temperature: how strongly the policy is kept close to the reference model.
    beta=0.1,
    # Standard (sigmoid) DPO loss.
    loss_type="sigmoid",
)


trainer = DPOTrainer(
    model,
    ref_model=None,  # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=data_module.train_dataset,
    eval_dataset=data_module.val_dataset,
    # NOTE(review): newer TRL releases rename this argument to `processing_class`;
    # keep `tokenizer=` while the installed TRL version still accepts it.
    tokenizer=tokenizer,  # type: ignore
)

In [None]:
# Smoke-test the DPO loss on the first few training batches before the real run.
# Iterating the whole dataloader here would cost an extra full forward epoch and print
# metrics for every batch, so stop after `num_sanity_batches` (None = every batch).
num_sanity_batches = 1

# Make sure not to create a computation graph, or the model will OOM
with torch.no_grad():
    # NOTE(review): with precompute_ref_log_probs=True, building the train dataloader
    # presumably runs the reference log-prob pass over the whole train set first
    # (and caches it for trainer.train()) — confirm for the installed TRL version.
    for batch_idx, batch in enumerate(trainer.get_train_dataloader()):
        if num_sanity_batches is not None and batch_idx >= num_sanity_batches:
            break
        # Evaluate the model the trainer actually optimizes (trainer.model, which wraps
        # the LoRA adapters from peft_config) rather than the bare base `model`.
        loss, metrics = trainer.compute_loss(trainer.model, batch, return_outputs=True)
        print(loss)
        display(metrics)

In [None]:
# Run DPO training, keep the returned training summary, and show it as the cell output.
train_result = trainer.train()
train_result