In [1]:
from datasets import load_from_disk, load_dataset
from dataset_generator import generate_dpo_dataset
from trl import DPOConfig, DPOTrainer
import os
# --- Experiment tracking ---------------------------------------------------
# Set before any W&B run is created so the run lands in the intended project.
os.environ["WANDB_PROJECT"] = 'CARS'
os.environ["WANDB_NOTEBOOK_NAME"] = "dpo_training.ipynb"

# --- Configuration ---------------------------------------------------------
dataset_path = "altas-ai/corrective_dataset_MATH_LLAMA3_8b_ZeroShot_COT"
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
# Prompt template kept for reference only (currently not applied):
#prompt = "{problem} \nPlease reason step by step, and put your final answer within \\boxed{{}}.\nApproach: "

EVAL_FRACTION = 0.2  # share of the 'train' split held out for evaluation
SPLIT_SEED = 42      # fixed so the train/eval partition is identical on every run

# Source column -> name used for the preference pairs downstream.
DPO_COLUMN_MAP = {
    "original_prompt": "prompt",
    "incorrect_completion": "rejected",
    "correct_completion": "chosen",
}

# --- Load and split --------------------------------------------------------
dataset = load_dataset(dataset_path)
# Without a seed the shuffle differs per kernel restart, which makes eval
# numbers incomparable between runs and can move eval rows into train.
split = dataset['train'].train_test_split(test_size=EVAL_FRACTION, seed=SPLIT_SEED)

# Apply the same renames to both partitions from a single mapping.
train_dataset = split['train'].rename_columns(DPO_COLUMN_MAP)
eval_dataset = split['test'].rename_columns(DPO_COLUMN_MAP)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from unsloth import FastLanguageModel
import torch
# Arguments for loading the base checkpoint: 4-bit loading is requested and
# the dtype is set to bfloat16.
base_model_args = dict(
    model_name=model_path,
    load_in_4bit=True,
    dtype=torch.bfloat16,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_args)

# LoRA adapter settings.
# target_modules is not passed, so the library's default selection applies.
# Alternatives left here for reference:
#   ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
#   ["q_proj", "v_proj"]
# NOTE(review): the saved training log reports 3,407,872 trainable parameters,
# which looks like rank-8 adapters on q_proj and v_proj only -- confirm which
# modules are actually targeted before comparing against that output.
lora_args = dict(
    r=8,
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",     # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
)
# Wrap the loaded model with the adapters; `model` is rebound to the wrapped model.
model = FastLanguageModel.get_peft_model(model, **lora_args)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth: Fast Llama patching release 2024.7
   \\   /|    GPU: NVIDIA GeForce RTX 3080 Ti Laptop GPU. Max memory: 15.732 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2024.7 patched 32 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


In [3]:
# Training configuration for the DPO run, grouped by concern.
dpo_args = DPOConfig(
    # Output and checkpointing: save every 100 steps, keep at most 3
    # checkpoints, and reload the best one when training ends.
    output_dir="models",
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,

    # Evaluation cadence (same interval as checkpoint saving).
    eval_strategy="steps",
    eval_steps=100,

    # Logging: every step, reported to Weights & Biases under this run name.
    logging_strategy="steps",
    logging_steps=1,
    report_to="wandb",
    run_name="DPO_2",

    # Schedule and memory: one sample per device with 4 accumulation steps.
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    eval_accumulation_steps=4,
    gradient_checkpointing=True,
    bf16=True,

    # DPO objective.
    beta=0.1,

    # Sequence-length limits and truncation.
    max_length=2048,
    max_prompt_length=1024,
    max_target_length=1024,
    truncation_mode="keep_start",
    remove_unused_columns=False,
)

# No separate reference model is supplied (ref_model=None).
# NOTE(review): with an adapter-wrapped model TRL is expected to derive the
# reference from the base weights -- confirm for the installed TRL version.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=dpo_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

Map: 100%|██████████| 2121/2121 [00:03<00:00, 676.19 examples/s]
Map: 100%|██████████| 531/531 [00:00<00:00, 693.11 examples/s]


In [4]:
# Run the DPO fine-tuning loop and keep the returned summary in a variable;
# the bare name on the last line displays it as the cell output.
train_output = dpo_trainer.train()
train_output

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 2,121 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 1 | Gradient Accumulation steps = 4
\        /    Total batch size = 4 | Total steps = 530
 "-____-"     Number of trainable parameters = 3,407,872
[34m[1mwandb[0m: Currently logged in as: [33mmark-chen-next[0m ([33mteam-quantum[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/530 [00:00<?, ?it/s]Could not estimate the number of tokens of the input, floating-point operations will not be computed
  0%|          | 1/530 [00:06<54:28,  6.18s/it]

{'loss': 0.6931, 'grad_norm': 3.0271146297454834, 'learning_rate': 4.9905660377358493e-05, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -95.05059814453125, 'logps/chosen': -93.76498413085938, 'logits/rejected': 0.3297100365161896, 'logits/chosen': 0.2812022268772125, 'epoch': 0.0}


  0%|          | 2/530 [00:11<51:06,  5.81s/it]

{'loss': 0.6907, 'grad_norm': 4.685053825378418, 'learning_rate': 4.9811320754716985e-05, 'rewards/chosen': 0.004506874363869429, 'rewards/rejected': -0.0017758364556357265, 'rewards/accuracies': 0.75, 'rewards/margins': 0.006282712332904339, 'logps/rejected': -116.45826721191406, 'logps/chosen': -131.8994140625, 'logits/rejected': 0.05328919738531113, 'logits/chosen': -0.044286224991083145, 'epoch': 0.0}


  1%|          | 3/530 [00:20<1:03:40,  7.25s/it]

{'loss': 0.6801, 'grad_norm': 5.621777057647705, 'learning_rate': 4.9716981132075476e-05, 'rewards/chosen': 0.001522732200101018, 'rewards/rejected': -0.02511920966207981, 'rewards/accuracies': 0.5, 'rewards/margins': 0.02664194256067276, 'logps/rejected': -95.37469482421875, 'logps/chosen': -95.34652709960938, 'logits/rejected': 0.29182183742523193, 'logits/chosen': 0.2353961169719696, 'epoch': 0.01}


  1%|          | 4/530 [00:32<1:19:11,  9.03s/it]

{'loss': 0.711, 'grad_norm': 4.914802551269531, 'learning_rate': 4.962264150943397e-05, 'rewards/chosen': -0.0047508240677416325, 'rewards/rejected': 0.03016681969165802, 'rewards/accuracies': 0.25, 'rewards/margins': -0.034917641431093216, 'logps/rejected': -141.12496948242188, 'logps/chosen': -139.4482879638672, 'logits/rejected': 0.7066179513931274, 'logits/chosen': 0.4385793209075928, 'epoch': 0.01}


  1%|          | 5/530 [00:39<1:11:56,  8.22s/it]

{'loss': 0.7084, 'grad_norm': 3.577291488647461, 'learning_rate': 4.952830188679246e-05, 'rewards/chosen': -0.032563209533691406, 'rewards/rejected': -0.002468680962920189, 'rewards/accuracies': 0.25, 'rewards/margins': -0.030094526708126068, 'logps/rejected': -121.73657989501953, 'logps/chosen': -135.8009033203125, 'logits/rejected': 0.12698335945606232, 'logits/chosen': 0.035148195922374725, 'epoch': 0.01}


  1%|          | 6/530 [00:45<1:06:28,  7.61s/it]

{'loss': 0.6714, 'grad_norm': 3.205763101577759, 'learning_rate': 4.943396226415095e-05, 'rewards/chosen': 0.021212007850408554, 'rewards/rejected': -0.023925019428133965, 'rewards/accuracies': 0.75, 'rewards/margins': 0.04513702541589737, 'logps/rejected': -80.84663391113281, 'logps/chosen': -123.96571350097656, 'logits/rejected': 0.04646002873778343, 'logits/chosen': 0.014243833720684052, 'epoch': 0.01}


  1%|▏         | 7/530 [00:51<1:01:36,  7.07s/it]

{'loss': 0.679, 'grad_norm': 3.9472076892852783, 'learning_rate': 4.933962264150943e-05, 'rewards/chosen': -0.003227042267099023, 'rewards/rejected': -0.03255443647503853, 'rewards/accuracies': 0.5, 'rewards/margins': 0.02932739444077015, 'logps/rejected': -101.75228881835938, 'logps/chosen': -116.02317810058594, 'logits/rejected': 0.24350719153881073, 'logits/chosen': 0.3324413001537323, 'epoch': 0.01}


  2%|▏         | 8/530 [00:59<1:03:10,  7.26s/it]

{'loss': 0.649, 'grad_norm': 3.621281147003174, 'learning_rate': 4.9245283018867924e-05, 'rewards/chosen': 0.031432151794433594, 'rewards/rejected': -0.059331707656383514, 'rewards/accuracies': 1.0, 'rewards/margins': 0.09076385945081711, 'logps/rejected': -75.12834167480469, 'logps/chosen': -76.53816986083984, 'logits/rejected': 0.415519118309021, 'logits/chosen': 0.07967068254947662, 'epoch': 0.02}


  2%|▏         | 9/530 [01:05<1:00:15,  6.94s/it]

{'loss': 0.6903, 'grad_norm': 3.58050537109375, 'learning_rate': 4.9150943396226415e-05, 'rewards/chosen': -0.021831704303622246, 'rewards/rejected': -0.02833118475973606, 'rewards/accuracies': 0.25, 'rewards/margins': 0.006499480921775103, 'logps/rejected': -110.13899230957031, 'logps/chosen': -139.6927032470703, 'logits/rejected': -0.07249025255441666, 'logits/chosen': -0.10604231804609299, 'epoch': 0.02}


In [None]:
# Finish the active Weights & Biases run once training is done (the trainer
# is configured with report_to="wandb"), so the run is not left open in this kernel.
import wandb
wandb.finish()