# Fine-tune model with DPO

## Goal

Let's see if using DPO can create a better model.

## Imports

In [1]:
import numpy as np
import pandas as pd

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
)

from trl import DPOConfig, DPOTrainer
from datasets import Dataset

## Load model

In [2]:
model_path = '/home/gbarbadillo/data/deepseekmath'
config = AutoConfig.from_pretrained(model_path)
config.gradient_checkpointing = True

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map='auto',
    torch_dtype="auto", #torch.bfloat16 does not show speed differences
    trust_remote_code=True,
    quantization_config=None,
    config=config
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# TODO: check pad token on prompt recovery notebook
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Load data

In [4]:
df = pd.read_csv('/mnt/hdd0/Kaggle/aimo/external_data/dpo/v0.csv')
df.head()

Unnamed: 0,prompt,chosen,rejected,problem_idx,max_prompt_length,chosen_length,rejected_length
0,You are an expert mathematical programmer. Sol...,```python\nfrom itertools import combinations\...,"```python\nfrom sympy import binomial, symbols...",0,144,123,212
1,User: John computes the sum of the elements of...,```python\nfrom sympy import *\nfrom itertools...,"```python\nfrom sympy import binomial, symbols...",0,96,188,620
2,Below is a math problem you are to solve (non ...,```python\nfrom sympy import *\nfrom itertools...,```python\nfrom sympy import binomial\n\ndef s...,0,169,230,198
3,\nUser: John computes the sum of the elements ...,```python\nfrom itertools import combinations\...,"```python\nfrom sympy import binomial, Rationa...",0,153,177,256
4,User: John computes the sum of the elements of...,```python\nfrom itertools import combinations\...,```python\nfrom sympy import *\n\ndef sum_of_s...,0,89,173,488


In [5]:
unique_problem_ids = df['problem_idx'].unique()
len(unique_problem_ids)

509

In [6]:
train_problem_ids = np.random.choice(unique_problem_ids, int(0.95 * len(unique_problem_ids)), replace=False)
train_df = df[df['problem_idx'].isin(train_problem_ids)]
test_df = df[~df['problem_idx'].isin(train_problem_ids)]
assert len(train_df) + len(test_df) == len(df)
assert set(train_df['problem_idx'].unique()).intersection(set(test_df['problem_idx'].unique())) == set()
print(len(train_df), len(test_df))

10154 507


In [7]:
train_dataset = Dataset.from_pandas(train_df[['prompt', 'chosen', 'rejected']])
eval_dataset = Dataset.from_pandas(test_df[['prompt', 'chosen', 'rejected']])

## Fine-tuning

In [8]:
from peft import LoraConfig

# LoRA config based on https://github.com/ironbar/prompt_recovery/blob/main/notebooks/014_fine-tune_mistral_v2.ipynb
peft_config = LoraConfig(
        lora_alpha=64,
        lora_dropout=0.05, # 0.1, althought Vaca suggested to use 0.05 for big models
        # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters
        r=16,
        bias="none",
        # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

In [9]:
# https://www.philschmid.de/dpo-align-llms-in-2024-with-trl
# https://huggingface.co/docs/trl/main/en/dpo_trainer
args = DPOConfig(
    output_dir="/mnt/hdd0/Kaggle/aimo/experiments/18_dpo/01_dpo-first-steps",               # directory to save and repository id
    num_train_epochs=1,                     # number of training epochs
    per_device_train_batch_size=2,         # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    per_device_eval_batch_size=4,           # batch size for evaluation
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    learning_rate=2e-5,                     # 10x higher LR than QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.1,                       # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",             # use cosine learning rate scheduler
    logging_steps=10,
    save_steps=100,                         # when to save checkpoint
    # save_total_limit=2,                     # limit the total amount of checkpoints
    eval_strategy="steps",                  # evaluate every n steps
    eval_steps=100,                         # when to evaluate
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    push_to_hub=False,                      # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    model_init_kwargs=None,
    max_length=1024,
    max_prompt_length=300,
    beta=0.1,                               # same as huggingface documentation, default value
    loss_type="sigmoid",                    # default value in huggingface documentation
)

trainer = DPOTrainer(
    model,
    ref_model=None, # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)



Map:   0%|          | 0/10154 [00:00<?, ? examples/s]

Map:   0%|          | 0/507 [00:00<?, ? examples/s]

In [10]:
trainer.train()

  0%|          | 0/634 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 0.6944, 'grad_norm': 3.203125, 'learning_rate': 7.8125e-06, 'rewards/chosen': -0.0004793582484126091, 'rewards/rejected': 4.7817306040087715e-05, 'rewards/accuracies': 0.4424999952316284, 'rewards/margins': -0.0005271749687381089, 'logps/rejected': -91.810302734375, 'logps/chosen': -74.58899688720703, 'logits/rejected': 12.51860237121582, 'logits/chosen': 10.896106719970703, 'epoch': 0.04}
{'loss': 0.688, 'grad_norm': 4.15625, 'learning_rate': 1.5625e-05, 'rewards/chosen': -0.022095371037721634, 'rewards/rejected': -0.0355038158595562, 'rewards/accuracies': 0.5475000143051147, 'rewards/margins': 0.013408441096544266, 'logps/rejected': -92.0755844116211, 'logps/chosen': -77.1809310913086, 'logits/rejected': 12.603387832641602, 'logits/chosen': 11.228765487670898, 'epoch': 0.08}
{'loss': 0.6682, 'grad_norm': 3.84375, 'learning_rate': 1.9981627325622104e-05, 'rewards/chosen': -0.1730453372001648, 'rewards/rejected': -0.24043823778629303, 'rewards/accuracies': 0.5975000262260437, 

  0%|          | 0/127 [00:00<?, ?it/s]

{'eval_loss': 0.5959993004798889, 'eval_runtime': 388.3928, 'eval_samples_per_second': 1.305, 'eval_steps_per_second': 0.327, 'eval_rewards/chosen': -0.11285519599914551, 'eval_rewards/rejected': -0.39046603441238403, 'eval_rewards/accuracies': 0.7467191815376282, 'eval_rewards/margins': 0.2776108384132385, 'eval_logps/rejected': -103.41634368896484, 'eval_logps/chosen': -81.8982162475586, 'eval_logits/rejected': 10.685792922973633, 'eval_logits/chosen': 9.107816696166992, 'epoch': 0.16}




{'loss': 0.5544, 'grad_norm': 3.875, 'learning_rate': 1.9440132825081107e-05, 'rewards/chosen': 0.09427224844694138, 'rewards/rejected': -0.3348385691642761, 'rewards/accuracies': 0.7749999761581421, 'rewards/margins': 0.4291108250617981, 'logps/rejected': -94.19165802001953, 'logps/chosen': -73.97393798828125, 'logits/rejected': 11.793439865112305, 'logits/chosen': 10.369733810424805, 'epoch': 0.2}
{'loss': 0.5003, 'grad_norm': 6.34375, 'learning_rate': 1.8897520431560435e-05, 'rewards/chosen': 0.1144249439239502, 'rewards/rejected': -0.5853338837623596, 'rewards/accuracies': 0.7900000214576721, 'rewards/margins': 0.6997588276863098, 'logps/rejected': -95.39093017578125, 'logps/chosen': -73.02861785888672, 'logits/rejected': 10.666786193847656, 'logits/chosen': 9.123523712158203, 'epoch': 0.24}
{'loss': 0.4456, 'grad_norm': 5.5, 'learning_rate': 1.8186248146866928e-05, 'rewards/chosen': 0.109018474817276, 'rewards/rejected': -0.9023540019989014, 'rewards/accuracies': 0.795000016689300

  0%|          | 0/127 [00:00<?, ?it/s]

{'eval_loss': 0.41919684410095215, 'eval_runtime': 386.4056, 'eval_samples_per_second': 1.312, 'eval_steps_per_second': 0.329, 'eval_rewards/chosen': -0.16413454711437225, 'eval_rewards/rejected': -1.4692895412445068, 'eval_rewards/accuracies': 0.8372703790664673, 'eval_rewards/margins': 1.3051550388336182, 'eval_logps/rejected': -114.20457458496094, 'eval_logps/chosen': -82.4110107421875, 'eval_logits/rejected': 8.588448524475098, 'eval_logits/chosen': 6.895272731781006, 'epoch': 0.32}




{'loss': 0.4166, 'grad_norm': 3.046875, 'learning_rate': 1.6314596443560777e-05, 'rewards/chosen': 0.07803937047719955, 'rewards/rejected': -1.2395579814910889, 'rewards/accuracies': 0.8349999785423279, 'rewards/margins': 1.3175971508026123, 'logps/rejected': -104.66545104980469, 'logps/chosen': -75.92522430419922, 'logits/rejected': 9.41341495513916, 'logits/chosen': 7.779532432556152, 'epoch': 0.35}
{'loss': 0.4082, 'grad_norm': 5.5625, 'learning_rate': 1.5189695737812153e-05, 'rewards/chosen': 0.1666778177022934, 'rewards/rejected': -1.1636799573898315, 'rewards/accuracies': 0.8149999976158142, 'rewards/margins': 1.3303577899932861, 'logps/rejected': -100.0866470336914, 'logps/chosen': -74.34703063964844, 'logits/rejected': 9.636592864990234, 'logits/chosen': 8.3733491897583, 'epoch': 0.39}
{'loss': 0.3531, 'grad_norm': 5.03125, 'learning_rate': 1.3966420038143342e-05, 'rewards/chosen': 0.3202120065689087, 'rewards/rejected': -1.1516671180725098, 'rewards/accuracies': 0.877499997615

  0%|          | 0/127 [00:00<?, ?it/s]

{'eval_loss': 0.4208602011203766, 'eval_runtime': 387.3897, 'eval_samples_per_second': 1.309, 'eval_steps_per_second': 0.328, 'eval_rewards/chosen': -0.5477638840675354, 'eval_rewards/rejected': -1.8813204765319824, 'eval_rewards/accuracies': 0.8175853490829468, 'eval_rewards/margins': 1.3335566520690918, 'eval_logps/rejected': -118.32488250732422, 'eval_logps/chosen': -86.24730682373047, 'eval_logits/rejected': 8.237112998962402, 'eval_logits/chosen': 6.639263153076172, 'epoch': 0.47}




{'loss': 0.2871, 'grad_norm': 6.90625, 'learning_rate': 1.1318921713420691e-05, 'rewards/chosen': 0.061236314475536346, 'rewards/rejected': -1.8166781663894653, 'rewards/accuracies': 0.8999999761581421, 'rewards/margins': 1.8779144287109375, 'logps/rejected': -106.7466049194336, 'logps/chosen': -72.7172622680664, 'logits/rejected': 8.892027854919434, 'logits/chosen': 7.227189540863037, 'epoch': 0.51}
{'loss': 0.2721, 'grad_norm': 2.28125, 'learning_rate': 9.944884618454996e-06, 'rewards/chosen': 0.08680646866559982, 'rewards/rejected': -2.102017402648926, 'rewards/accuracies': 0.9150000214576721, 'rewards/margins': 2.188823699951172, 'logps/rejected': -109.77825927734375, 'logps/chosen': -73.68531036376953, 'logits/rejected': 8.373934745788574, 'logits/chosen': 6.708476543426514, 'epoch': 0.55}
{'loss': 0.2434, 'grad_norm': 3.8125, 'learning_rate': 8.571892281332213e-06, 'rewards/chosen': 0.002565858419984579, 'rewards/rejected': -2.312694787979126, 'rewards/accuracies': 0.915000021457

  0%|          | 0/127 [00:00<?, ?it/s]

{'eval_loss': 0.4043298661708832, 'eval_runtime': 394.5809, 'eval_samples_per_second': 1.285, 'eval_steps_per_second': 0.322, 'eval_rewards/chosen': -1.1567224264144897, 'eval_rewards/rejected': -2.899566411972046, 'eval_rewards/accuracies': 0.8175853490829468, 'eval_rewards/margins': 1.7428438663482666, 'eval_logps/rejected': -128.50735473632812, 'eval_logps/chosen': -92.3368911743164, 'eval_logits/rejected': 6.802318572998047, 'eval_logits/chosen': 5.2856645584106445, 'epoch': 0.63}




{'loss': 0.2285, 'grad_norm': 6.875, 'learning_rate': 5.932633569242e-06, 'rewards/chosen': -0.3080903887748718, 'rewards/rejected': -2.7032525539398193, 'rewards/accuracies': 0.9150000214576721, 'rewards/margins': 2.395162343978882, 'logps/rejected': -117.1850814819336, 'logps/chosen': -77.2620620727539, 'logits/rejected': 7.743144989013672, 'logits/chosen': 6.120285511016846, 'epoch': 0.67}
{'loss': 0.2886, 'grad_norm': 5.625, 'learning_rate': 4.716396535660412e-06, 'rewards/chosen': -0.5579458475112915, 'rewards/rejected': -2.9327666759490967, 'rewards/accuracies': 0.8974999785423279, 'rewards/margins': 2.3748204708099365, 'logps/rejected': -123.11552429199219, 'logps/chosen': -80.95430755615234, 'logits/rejected': 7.356427192687988, 'logits/chosen': 5.726658821105957, 'epoch': 0.71}
{'loss': 0.2415, 'grad_norm': 4.15625, 'learning_rate': 3.6003145949668338e-06, 'rewards/chosen': -0.3022012412548065, 'rewards/rejected': -2.789123773574829, 'rewards/accuracies': 0.9100000262260437, '

  0%|          | 0/127 [00:00<?, ?it/s]

{'eval_loss': 0.39460909366607666, 'eval_runtime': 374.1204, 'eval_samples_per_second': 1.355, 'eval_steps_per_second': 0.339, 'eval_rewards/chosen': -1.1515700817108154, 'eval_rewards/rejected': -2.9324898719787598, 'eval_rewards/accuracies': 0.8097112774848938, 'eval_rewards/margins': 1.7809196710586548, 'eval_logps/rejected': -128.83657836914062, 'eval_logps/chosen': -92.28536224365234, 'eval_logits/rejected': 6.485034465789795, 'eval_logits/chosen': 5.000820636749268, 'epoch': 0.79}




{'loss': 0.291, 'grad_norm': 7.5625, 'learning_rate': 1.7509414761388854e-06, 'rewards/chosen': -0.33489105105400085, 'rewards/rejected': -2.559382677078247, 'rewards/accuracies': 0.887499988079071, 'rewards/margins': 2.224491834640503, 'logps/rejected': -117.00155639648438, 'logps/chosen': -78.22147369384766, 'logits/rejected': 7.056517124176025, 'logits/chosen': 5.574831962585449, 'epoch': 0.83}
{'loss': 0.24, 'grad_norm': 4.78125, 'learning_rate': 1.0527067017923654e-06, 'rewards/chosen': -0.25393348932266235, 'rewards/rejected': -2.6048943996429443, 'rewards/accuracies': 0.9075000286102295, 'rewards/margins': 2.3509607315063477, 'logps/rejected': -112.43512725830078, 'logps/chosen': -76.33489990234375, 'logits/rejected': 7.386599063873291, 'logits/chosen': 5.990445613861084, 'epoch': 0.87}
{'loss': 0.2424, 'grad_norm': 3.09375, 'learning_rate': 5.240753046535396e-07, 'rewards/chosen': -0.3071029782295227, 'rewards/rejected': -2.753527879714966, 'rewards/accuracies': 0.8999999761581

  0%|          | 0/127 [00:00<?, ?it/s]

{'eval_loss': 0.3919623792171478, 'eval_runtime': 374.6222, 'eval_samples_per_second': 1.353, 'eval_steps_per_second': 0.339, 'eval_rewards/chosen': -1.1435599327087402, 'eval_rewards/rejected': -2.9309442043304443, 'eval_rewards/accuracies': 0.8156168460845947, 'eval_rewards/margins': 1.787384033203125, 'eval_logps/rejected': -128.8211212158203, 'eval_logps/chosen': -92.20526123046875, 'eval_logits/rejected': 6.475827217102051, 'eval_logits/chosen': 4.985001087188721, 'epoch': 0.95}




{'loss': 0.2465, 'grad_norm': 2.546875, 'learning_rate': 1.230030851695263e-08, 'rewards/chosen': -0.30363237857818604, 'rewards/rejected': -2.697334051132202, 'rewards/accuracies': 0.9225000143051147, 'rewards/margins': 2.3937015533447266, 'logps/rejected': -119.35763549804688, 'logps/chosen': -77.18180847167969, 'logits/rejected': 7.547821521759033, 'logits/chosen': 5.730653285980225, 'epoch': 0.98}
{'train_runtime': 17792.9492, 'train_samples_per_second': 0.571, 'train_steps_per_second': 0.036, 'train_loss': 0.3783860379589093, 'epoch': 1.0}


TrainOutput(global_step=634, training_loss=0.3783860379589093, metrics={'train_runtime': 17792.9492, 'train_samples_per_second': 0.571, 'train_steps_per_second': 0.036, 'total_flos': 0.0, 'train_loss': 0.3783860379589093, 'epoch': 0.9990151664368722})

## TODO

- [x] Count the tokens in the train set, max length and max prompt length
- [x] Do the training parameters from the example have sense?
- [ ] Shuffle the data?
- [ ] Verify that the model generates text correctly, on the prompt recovery challenge after fine-tuning the model forgot to use the eos token.