In [1]:
from datasets import load_from_disk
from dataset_generator import generate_dpo_dataset
from trl import DPOConfig, DPOTrainer
import os
# Weights & Biases settings for this run: project name and source notebook.
os.environ.update(
    {
        "WANDB_PROJECT": "CAMR",
        "WANDB_NOTEBOOK_NAME": "dpo_training.ipynb",
    }
)

# Location of the dataset saved on disk and the base model to fine-tune.
dataset_path = "dataset/corrective_copy_dataset_MATH_LLAMA3_8b_ZeroShot_COT"
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"

# Directory passed to the trainer as its output_dir.
model_output_dir = "models/corrective_copy_dpo"

# Load the dataset saved at dataset_path and run it through the project's
# generate_dpo_dataset helper (presumably producing prompt/chosen/rejected
# records for DPO -- confirm against dataset_generator).
dataset = load_from_disk(dataset_path)
dataset = generate_dpo_dataset(dataset)

# Pin the shuffle seed so the 80/20 split is identical on every
# Restart & Run All. Without it the evaluation set changes from run to run,
# which makes validation metrics and best-checkpoint selection
# incomparable between runs.
SPLIT_SEED = 42
split = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

train_dataset = split['train']
eval_dataset = split['test']

[2024-07-12 17:11:20,333] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


In [2]:
from unsloth import FastLanguageModel
import torch
# Load the base model and its tokenizer, 4-bit quantised, requesting bfloat16.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    dtype=torch.bfloat16,
    load_in_4bit=True,
)

# Module names that receive LoRA adapters (attention and MLP projections).
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Wrap the loaded model with rank-8 LoRA adapters on the modules above.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    lora_alpha=16,
    target_modules=lora_target_modules,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",     # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth: Fast Llama patching release 2024.7
   \\   /|    GPU: NVIDIA GeForce RTX 4060 Ti. Max memory: 15.611 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.1+cu121. CUDA = 8.9. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.27. FA2 = True]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth 2024.7 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [3]:


# Training arguments for the DPO run, grouped by concern.
dpo_args = DPOConfig(
    # Output location and experiment tracking.
    output_dir=model_output_dir,
    report_to="wandb",
    run_name="DPO",

    # Schedule: evaluate and checkpoint on the same 100-step cadence,
    # log every step.
    num_train_epochs=3,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    logging_strategy="steps",
    logging_steps=1,

    # Keep at most three checkpoints and reload the best one when done.
    save_total_limit=3,
    load_best_model_at_end=True,

    # Batching: one sample per device, gradients accumulated over 4 steps.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    eval_accumulation_steps=4,
    gradient_checkpointing=True,

    # Sequence length limits and truncation behaviour.
    max_length=1300,
    max_prompt_length=300,
    max_target_length=1000,
    truncation_mode="keep_start",
    remove_unused_columns=False,

    # Optimisation.
    learning_rate=5e-6,
    bf16=True,
)

# No separate reference model is supplied (ref_model=None); the trainer is
# expected to derive the reference from the PEFT-wrapped model -- confirm
# against the installed TRL version.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=dpo_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

Map:   0%|          | 0/3660 [00:00<?, ? examples/s]

Map:   0%|          | 0/915 [00:00<?, ? examples/s]

In [4]:
# Run DPO fine-tuning with the trainer built in the previous cell; the call's
# return value is the cell's displayed output.
dpo_trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 3,660 | Num Epochs = 3
O^O/ \_/ \    Batch size per device = 1 | Gradient Accumulation steps = 4
\        /    Total batch size = 4 | Total steps = 2,745
 "-____-"     Number of trainable parameters = 20,971,520


[34m[1mwandb[0m: Currently logged in as: [33mchengpong1127[0m ([33mteam-quantum[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
100,0.5403,0.593278,0.255541,0.024361,0.816393,0.231179,-105.13485,-115.795311,-0.153971,-0.203402
200,0.0965,0.434301,0.576965,-0.308383,0.824044,0.885348,-108.462296,-112.581062,-0.104641,-0.143381
300,0.0758,0.374344,0.245502,-1.122327,0.83388,1.367829,-116.601746,-115.895691,-0.050641,-0.081309
400,0.3162,0.344901,-0.17011,-1.937583,0.839344,1.767473,-124.754295,-120.051811,-0.001335,-0.029887
500,0.3059,0.326381,-0.738955,-2.860226,0.848087,2.12127,-133.980728,-125.740265,0.02942,0.000966
600,0.1065,0.301508,-0.689236,-2.94546,0.861202,2.256224,-134.833069,-125.243073,0.064508,0.028438
700,0.43,0.285132,-1.293168,-3.891044,0.866667,2.597876,-144.28894,-131.282394,0.079141,0.039709
800,0.1728,0.26798,-1.375643,-4.110973,0.878689,2.735331,-146.488205,-132.107132,0.101712,0.058228
900,0.2367,0.259899,-2.261848,-5.401323,0.879781,3.139476,-159.391708,-140.969193,0.116246,0.074983
1000,0.0033,0.245963,-2.052123,-5.200417,0.889618,3.148295,-157.382629,-138.871948,0.110899,0.063145


In [None]:
# Finish the Weights & Biases run once training is done.
import wandb
wandb.finish()