# Fine-tune model with DPO

## Goal

Let's see if using DPO can create a better model.

## Imports

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import numpy as np
import pandas as pd

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
)

from trl import DPOConfig, DPOTrainer
from datasets import Dataset

from peft import LoraConfig

## Load model

In [3]:
model_path = '/home/gbarbadillo/data/deepseekmath'
config = AutoConfig.from_pretrained(model_path)
config.gradient_checkpointing = True

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map='auto',
    torch_dtype="auto", #torch.bfloat16 does not show speed differences
    trust_remote_code=True,
    quantization_config=None,
    config=config
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# TODO: check pad token on prompt recovery notebook
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Load data

In [10]:
df = pd.read_csv('/mnt/hdd0/Kaggle/aimo/external_data/dpo/v1.csv')
df.head()

Unnamed: 0,prompt,chosen,rejected,problem_idx,max_prompt_length,chosen_length,rejected_length
0,User: John computes the sum of the elements of...,from sympy import *\n\ndef sum_of_subset_sums(...,from itertools import combinations\n\ndef sum_...,0,92,166,159
1,You are an expert mathematical programmer. Sol...,from itertools import combinations\n\ndef sum_...,from sympy import binomial\n\ndef sum_of_subse...,0,147,168,199
2,\nUser: John computes the sum of the elements ...,from itertools import combinations\n\ndef sum_...,"from sympy import binomial, Rational, simplify...",0,156,118,405
3,User: John computes the sum of the elements of...,from itertools import combinations\n\ndef sum_...,"from sympy import binomial, summation, symbols...",0,99,179,459
4,Below is a math problem you are to solve (non ...,from itertools import combinations\n\ndef sum_...,from itertools import combinations\n\ndef sum_...,0,100,163,189


In [11]:
unique_problem_ids = df['problem_idx'].unique()
len(unique_problem_ids)

509

In [14]:
np.random.seed(7)
train_problem_ids = np.random.choice(unique_problem_ids, int(0.95 * len(unique_problem_ids)), replace=False)
train_df = df[df['problem_idx'].isin(train_problem_ids)]
test_df = df[~df['problem_idx'].isin(train_problem_ids)]
assert len(train_df) + len(test_df) == len(df)
assert set(train_df['problem_idx'].unique()).intersection(set(test_df['problem_idx'].unique())) == set()
print(len(train_df), len(test_df))

10217 444


In [15]:
# shuffle the train dataset
train_dataset = Dataset.from_pandas(train_df[['prompt', 'chosen', 'rejected']].sample(frac=1))
eval_dataset = Dataset.from_pandas(test_df[['prompt', 'chosen', 'rejected']])

## Fine-tuning

In [9]:
raise

RuntimeError: No active exception to reraise

### First experiment

In [None]:
from peft import LoraConfig

# LoRA config based on https://github.com/ironbar/prompt_recovery/blob/main/notebooks/014_fine-tune_mistral_v2.ipynb
peft_config = LoraConfig(
        lora_alpha=64,
        lora_dropout=0.05, # 0.1, althought Vaca suggested to use 0.05 for big models
        # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters
        r=16,
        bias="none",
        # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

In [None]:
# https://www.philschmid.de/dpo-align-llms-in-2024-with-trl
# https://huggingface.co/docs/trl/main/en/dpo_trainer
args = DPOConfig(
    output_dir="/mnt/hdd0/Kaggle/aimo/experiments/18_dpo/01_dpo-first-steps",               # directory to save and repository id
    num_train_epochs=1,                     # number of training epochs
    per_device_train_batch_size=2,         # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    per_device_eval_batch_size=4,           # batch size for evaluation
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    learning_rate=2e-5,                     # 10x higher LR than QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.1,                       # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",             # use cosine learning rate scheduler
    logging_steps=10,
    save_steps=100,                         # when to save checkpoint
    # save_total_limit=2,                     # limit the total amount of checkpoints
    eval_strategy="steps",                  # evaluate every n steps
    eval_steps=100,                         # when to evaluate
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    push_to_hub=False,                      # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    model_init_kwargs=None,
    max_length=1024,
    max_prompt_length=300,
    beta=0.1,                               # same as huggingface documentation, default value
    loss_type="sigmoid",                    # default value in huggingface documentation
)

trainer = DPOTrainer(
    model,
    ref_model=None, # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

### 2. Shuffle the data

In [None]:
from peft import LoraConfig

# LoRA config based on https://github.com/ironbar/prompt_recovery/blob/main/notebooks/014_fine-tune_mistral_v2.ipynb
peft_config = LoraConfig(
        lora_alpha=64,
        lora_dropout=0.05, # 0.1, althought Vaca suggested to use 0.05 for big models
        # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters
        r=16,
        bias="none",
        # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

In [None]:
# https://www.philschmid.de/dpo-align-llms-in-2024-with-trl
# https://huggingface.co/docs/trl/main/en/dpo_trainer
args = DPOConfig(
    output_dir="/mnt/hdd0/Kaggle/aimo/experiments/18_dpo/02_shuffle_train_set",               # directory to save and repository id
    num_train_epochs=1,                     # number of training epochs
    per_device_train_batch_size=2,         # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    per_device_eval_batch_size=4,           # batch size for evaluation
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    learning_rate=2e-5,                     # 10x higher LR than QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.1,                       # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",             # use cosine learning rate scheduler
    logging_steps=10,
    save_steps=100,                         # when to save checkpoint
    # save_total_limit=2,                     # limit the total amount of checkpoints
    eval_strategy="steps",                  # evaluate every n steps
    eval_steps=100,                         # when to evaluate
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    push_to_hub=False,                      # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    model_init_kwargs=None,
    max_length=1024,
    max_prompt_length=300,
    beta=0.1,                               # same as huggingface documentation, default value
    loss_type="sigmoid",                    # default value in huggingface documentation
)

trainer = DPOTrainer(
    model,
    ref_model=None, # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()

### 3. Remove tf32

Does it train faster? No, but memory consumption is half. Thus I can train on a single GPU or double the batch size.

If I train on a single gpu the training is faster, each iteration takes 18-20 seconds instead of 26. This happens because it avoids the comunication between gpus and the gpu is all the time working.

I have tried using data parallelism but does not work.

```python
import torch
model = torch.nn.DataParallel(model)
```

In [None]:
# LoRA config based on https://github.com/ironbar/prompt_recovery/blob/main/notebooks/014_fine-tune_mistral_v2.ipynb
peft_config = LoraConfig(
        lora_alpha=64,
        lora_dropout=0.05, # 0.1, althought Vaca suggested to use 0.05 for big models
        # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters
        r=16,
        bias="none",
        # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

# https://www.philschmid.de/dpo-align-llms-in-2024-with-trl
# https://huggingface.co/docs/trl/main/en/dpo_trainer
# https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
args = DPOConfig(
    output_dir="/mnt/hdd0/Kaggle/aimo/experiments/18_dpo/03_4_epochs",               # directory to save and repository id
    num_train_epochs=4,                     # number of training epochs
    per_device_train_batch_size=2,         # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    per_device_eval_batch_size=2,           # batch size for evaluation
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    learning_rate=2e-5,                     # 10x higher LR than QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.1,                       # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",             # use cosine learning rate scheduler
    logging_steps=10,
    save_steps=100,                         # when to save checkpoint
    # save_total_limit=2,                     # limit the total amount of checkpoints
    eval_strategy="steps",                  # evaluate every n steps
    eval_steps=100,                         # when to evaluate
    bf16=True,                              # use bfloat16 precision
    # tf32=True,                              # use tf32 precision
    push_to_hub=False,                      # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    model_init_kwargs=None,
    max_length=1024,
    max_prompt_length=300,
    beta=0.1,                               # same as huggingface documentation, default value
    loss_type="sigmoid",                    # default value in huggingface documentation
)

trainer = DPOTrainer(
    model,
    ref_model=None, # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()

### 4. Use V1 dataset that has python block on prompt

In [16]:
# LoRA config based on https://github.com/ironbar/prompt_recovery/blob/main/notebooks/014_fine-tune_mistral_v2.ipynb
peft_config = LoraConfig(
        lora_alpha=64,
        lora_dropout=0.05, # 0.1, althought Vaca suggested to use 0.05 for big models
        # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters
        r=16,
        bias="none",
        # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)

# https://www.philschmid.de/dpo-align-llms-in-2024-with-trl
# https://huggingface.co/docs/trl/main/en/dpo_trainer
# https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
args = DPOConfig(
    output_dir="/mnt/hdd0/Kaggle/aimo/experiments/18_dpo/06_v1_dataset",               # directory to save and repository id
    num_train_epochs=4,                     # number of training epochs
    per_device_train_batch_size=2,         # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    per_device_eval_batch_size=2,           # batch size for evaluation
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    learning_rate=2e-5,                     # 10x higher LR than QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_steps=50,                       # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",             # use cosine learning rate scheduler
    logging_steps=10,
    save_steps=100,                         # when to save checkpoint
    # save_total_limit=2,                     # limit the total amount of checkpoints
    eval_strategy="steps",                  # evaluate every n steps
    eval_steps=100,                         # when to evaluate
    bf16=True,                              # use bfloat16 precision
    # tf32=True,                              # use tf32 precision
    push_to_hub=False,                      # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    model_init_kwargs=None,
    max_length=1024,
    max_prompt_length=300,
    beta=0.1,                               # same as huggingface documentation, default value
    loss_type="sigmoid",                    # default value in huggingface documentation
)

trainer = DPOTrainer(
    model,
    ref_model=None, # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()



Map:   0%|          | 0/10217 [00:00<?, ? examples/s]

Map:   0%|          | 0/444 [00:00<?, ? examples/s]

  0%|          | 0/2552 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 0.6953, 'grad_norm': 3.421875, 'learning_rate': 4.000000000000001e-06, 'rewards/chosen': -0.0010355999693274498, 'rewards/rejected': 0.0015163278440013528, 'rewards/accuracies': 0.3812499940395355, 'rewards/margins': -0.002551927464082837, 'logps/rejected': -89.96470642089844, 'logps/chosen': -73.45173645019531, 'logits/rejected': 12.976682662963867, 'logits/chosen': 11.250761032104492, 'epoch': 0.02}
{'loss': 0.6911, 'grad_norm': 3.484375, 'learning_rate': 8.000000000000001e-06, 'rewards/chosen': 0.004204194992780685, 'rewards/rejected': -0.0026238346472382545, 'rewards/accuracies': 0.5249999761581421, 'rewards/margins': 0.00682802964001894, 'logps/rejected': -93.73872375488281, 'logps/chosen': -75.77989959716797, 'logits/rejected': 12.580449104309082, 'logits/chosen': 11.264778137207031, 'epoch': 0.03}
{'loss': 0.6933, 'grad_norm': 3.8125, 'learning_rate': 1.2e-05, 'rewards/chosen': -0.007909061387181282, 'rewards/rejected': -0.010043960995972157, 'rewards/accuracies': 0.493

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.6154296398162842, 'eval_runtime': 241.8181, 'eval_samples_per_second': 1.836, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': 0.21208280324935913, 'eval_rewards/rejected': -0.014565981924533844, 'eval_rewards/accuracies': 0.7094594836235046, 'eval_rewards/margins': 0.22664877772331238, 'eval_logps/rejected': -90.63356018066406, 'eval_logps/chosen': -71.53353118896484, 'eval_logits/rejected': 12.089920043945312, 'eval_logits/chosen': 10.496620178222656, 'epoch': 0.16}




{'loss': 0.5609, 'grad_norm': 3.640625, 'learning_rate': 1.9971634384234003e-05, 'rewards/chosen': 0.35068756341934204, 'rewards/rejected': -0.047439176589250565, 'rewards/accuracies': 0.706250011920929, 'rewards/margins': 0.3981267511844635, 'logps/rejected': -91.0195541381836, 'logps/chosen': -68.90591430664062, 'logits/rejected': 11.695329666137695, 'logits/chosen': 10.039606094360352, 'epoch': 0.17}
{'loss': 0.5268, 'grad_norm': 4.75, 'learning_rate': 1.996139783974652e-05, 'rewards/chosen': 0.4217213988304138, 'rewards/rejected': -0.1326034516096115, 'rewards/accuracies': 0.824999988079071, 'rewards/margins': 0.5543248057365417, 'logps/rejected': -88.86412048339844, 'logps/chosen': -67.71923065185547, 'logits/rejected': 10.790011405944824, 'logits/chosen': 9.240800857543945, 'epoch': 0.19}
{'loss': 0.4844, 'grad_norm': 3.6875, 'learning_rate': 1.9949590788846255e-05, 'rewards/chosen': 0.46751540899276733, 'rewards/rejected': -0.25762778520584106, 'rewards/accuracies': 0.7937499880

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.5255338549613953, 'eval_runtime': 241.9775, 'eval_samples_per_second': 1.835, 'eval_steps_per_second': 0.917, 'eval_rewards/chosen': -0.19891761243343353, 'eval_rewards/rejected': -1.0811268091201782, 'eval_rewards/accuracies': 0.7252252101898193, 'eval_rewards/margins': 0.8822092413902283, 'eval_logps/rejected': -101.29917907714844, 'eval_logps/chosen': -75.64353942871094, 'eval_logits/rejected': 10.240870475769043, 'eval_logits/chosen': 8.471366882324219, 'epoch': 0.31}




{'loss': 0.358, 'grad_norm': 5.09375, 'learning_rate': 1.9798871373098845e-05, 'rewards/chosen': 0.261913925409317, 'rewards/rejected': -1.1730979681015015, 'rewards/accuracies': 0.8500000238418579, 'rewards/margins': 1.435011863708496, 'logps/rejected': -98.83033752441406, 'logps/chosen': -67.28208923339844, 'logits/rejected': 9.855779647827148, 'logits/chosen': 7.917764186859131, 'epoch': 0.33}
{'loss': 0.3286, 'grad_norm': 4.59375, 'learning_rate': 1.9773043129309123e-05, 'rewards/chosen': 0.1862579882144928, 'rewards/rejected': -1.662421464920044, 'rewards/accuracies': 0.893750011920929, 'rewards/margins': 1.8486793041229248, 'logps/rejected': -110.51202392578125, 'logps/chosen': -69.61744689941406, 'logits/rejected': 10.02865982055664, 'logits/chosen': 8.087186813354492, 'epoch': 0.34}
{'loss': 0.3816, 'grad_norm': 5.71875, 'learning_rate': 1.974567407496712e-05, 'rewards/chosen': 0.012288955971598625, 'rewards/rejected': -1.5592478513717651, 'rewards/accuracies': 0.83125001192092

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.5561589002609253, 'eval_runtime': 241.8968, 'eval_samples_per_second': 1.835, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -0.7597678303718567, 'eval_rewards/rejected': -1.8409866094589233, 'eval_rewards/accuracies': 0.7207207083702087, 'eval_rewards/margins': 1.0812187194824219, 'eval_logps/rejected': -108.89778900146484, 'eval_logps/chosen': -81.25204467773438, 'eval_logits/rejected': 8.22982120513916, 'eval_logits/chosen': 6.488012790679932, 'epoch': 0.47}




{'loss': 0.2253, 'grad_norm': 7.4375, 'learning_rate': 1.947182094039668e-05, 'rewards/chosen': 0.17852099239826202, 'rewards/rejected': -2.288707733154297, 'rewards/accuracies': 0.918749988079071, 'rewards/margins': 2.467228889465332, 'logps/rejected': -105.67063903808594, 'logps/chosen': -70.11473083496094, 'logits/rejected': 7.758152008056641, 'logits/chosen': 6.562679290771484, 'epoch': 0.49}
{'loss': 0.2148, 'grad_norm': 4.875, 'learning_rate': 1.9430807674052092e-05, 'rewards/chosen': 0.2110137939453125, 'rewards/rejected': -2.6799962520599365, 'rewards/accuracies': 0.9312499761581421, 'rewards/margins': 2.891010046005249, 'logps/rejected': -115.49813079833984, 'logps/chosen': -67.47633361816406, 'logits/rejected': 7.4566850662231445, 'logits/chosen': 5.550971984863281, 'epoch': 0.5}
{'loss': 0.2884, 'grad_norm': 9.5625, 'learning_rate': 1.9388307553737006e-05, 'rewards/chosen': 0.004266184754669666, 'rewards/rejected': -2.7237489223480225, 'rewards/accuracies': 0.912500023841857

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.6126380562782288, 'eval_runtime': 242.0941, 'eval_samples_per_second': 1.834, 'eval_steps_per_second': 0.917, 'eval_rewards/chosen': -1.1078107357025146, 'eval_rewards/rejected': -2.3548102378845215, 'eval_rewards/accuracies': 0.7184684872627258, 'eval_rewards/margins': 1.2469996213912964, 'eval_logps/rejected': -114.0360107421875, 'eval_logps/chosen': -84.73246765136719, 'eval_logits/rejected': 6.159877300262451, 'eval_logits/chosen': 4.411309242248535, 'epoch': 0.63}




{'loss': 0.2899, 'grad_norm': 11.25, 'learning_rate': 1.899563263509725e-05, 'rewards/chosen': 0.04579009860754013, 'rewards/rejected': -2.864750385284424, 'rewards/accuracies': 0.893750011920929, 'rewards/margins': 2.9105403423309326, 'logps/rejected': -124.10768127441406, 'logps/chosen': -72.30010223388672, 'logits/rejected': 5.737375736236572, 'logits/chosen': 3.709618330001831, 'epoch': 0.64}
{'loss': 0.2374, 'grad_norm': 6.03125, 'learning_rate': 1.89400801176212e-05, 'rewards/chosen': 0.2087489366531372, 'rewards/rejected': -2.9790165424346924, 'rewards/accuracies': 0.8999999761581421, 'rewards/margins': 3.187765598297119, 'logps/rejected': -115.80079650878906, 'logps/chosen': -70.25961303710938, 'logits/rejected': 6.558821201324463, 'logits/chosen': 4.50593376159668, 'epoch': 0.66}
{'loss': 0.1514, 'grad_norm': 8.5625, 'learning_rate': 1.8883118113908243e-05, 'rewards/chosen': 0.24212777614593506, 'rewards/rejected': -3.027545928955078, 'rewards/accuracies': 0.949999988079071, '

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.6795856356620789, 'eval_runtime': 241.8934, 'eval_samples_per_second': 1.836, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -1.8795621395111084, 'eval_rewards/rejected': -3.12912917137146, 'eval_rewards/accuracies': 0.684684693813324, 'eval_rewards/margins': 1.2495671510696411, 'eval_logps/rejected': -121.77919006347656, 'eval_logps/chosen': -92.4499740600586, 'eval_logits/rejected': 6.320049285888672, 'eval_logits/chosen': 4.5810346603393555, 'epoch': 0.78}




{'loss': 0.182, 'grad_norm': 5.59375, 'learning_rate': 1.8377804245773094e-05, 'rewards/chosen': -0.39600902795791626, 'rewards/rejected': -4.173902988433838, 'rewards/accuracies': 0.956250011920929, 'rewards/margins': 3.7778942584991455, 'logps/rejected': -123.6973648071289, 'logps/chosen': -76.1230239868164, 'logits/rejected': 6.39043664932251, 'logits/chosen': 5.02389669418335, 'epoch': 0.8}
{'loss': 0.0997, 'grad_norm': 4.65625, 'learning_rate': 1.8308587175317708e-05, 'rewards/chosen': -0.1254158914089203, 'rewards/rejected': -4.295899868011475, 'rewards/accuracies': 0.981249988079071, 'rewards/margins': 4.170483589172363, 'logps/rejected': -132.63800048828125, 'logps/chosen': -73.69773864746094, 'logits/rejected': 6.727433204650879, 'logits/chosen': 4.744823932647705, 'epoch': 0.81}
{'loss': 0.0916, 'grad_norm': 7.71875, 'learning_rate': 1.823806017932276e-05, 'rewards/chosen': -0.08714324980974197, 'rewards/rejected': -4.550434589385986, 'rewards/accuracies': 0.981249988079071, 

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.6472303867340088, 'eval_runtime': 241.958, 'eval_samples_per_second': 1.835, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -2.109342098236084, 'eval_rewards/rejected': -3.6435024738311768, 'eval_rewards/accuracies': 0.7139639854431152, 'eval_rewards/margins': 1.5341600179672241, 'eval_logps/rejected': -126.92292785644531, 'eval_logps/chosen': -94.7477798461914, 'eval_logits/rejected': 5.0846452713012695, 'eval_logits/chosen': 3.3640599250793457, 'epoch': 0.94}




{'loss': 0.1042, 'grad_norm': 11.5, 'learning_rate': 1.762806374471105e-05, 'rewards/chosen': -0.4803459048271179, 'rewards/rejected': -5.247006416320801, 'rewards/accuracies': 0.9750000238418579, 'rewards/margins': 4.766659736633301, 'logps/rejected': -138.08908081054688, 'logps/chosen': -79.90936279296875, 'logits/rejected': 5.17446231842041, 'logits/chosen': 3.205935001373291, 'epoch': 0.96}
{'loss': 0.1424, 'grad_norm': 2.546875, 'learning_rate': 1.7546271973660577e-05, 'rewards/chosen': -0.7516440153121948, 'rewards/rejected': -5.367527961730957, 'rewards/accuracies': 0.9437500238418579, 'rewards/margins': 4.615883827209473, 'logps/rejected': -143.3642120361328, 'logps/chosen': -82.21552276611328, 'logits/rejected': 4.980582237243652, 'logits/chosen': 3.2851309776306152, 'epoch': 0.97}
{'loss': 0.1092, 'grad_norm': 3.140625, 'learning_rate': 1.746329046310588e-05, 'rewards/chosen': -0.8194684982299805, 'rewards/rejected': -5.632216453552246, 'rewards/accuracies': 0.949999988079071

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.7407044172286987, 'eval_runtime': 241.9236, 'eval_samples_per_second': 1.835, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -3.395005702972412, 'eval_rewards/rejected': -5.643986225128174, 'eval_rewards/accuracies': 0.7567567825317383, 'eval_rewards/margins': 2.2489802837371826, 'eval_logps/rejected': -146.92776489257812, 'eval_logps/chosen': -107.60442352294922, 'eval_logits/rejected': 1.5974797010421753, 'eval_logits/chosen': -0.047791413962841034, 'epoch': 1.1}




{'loss': 0.0511, 'grad_norm': 1.9375, 'learning_rate': 1.6758216116827106e-05, 'rewards/chosen': -0.8284792900085449, 'rewards/rejected': -7.220571994781494, 'rewards/accuracies': 0.987500011920929, 'rewards/margins': 6.392092704772949, 'logps/rejected': -152.17739868164062, 'logps/chosen': -77.2468490600586, 'logits/rejected': 0.6463303565979004, 'logits/chosen': -0.9613253474235535, 'epoch': 1.11}
{'loss': 0.0751, 'grad_norm': 2.09375, 'learning_rate': 1.6665137491605924e-05, 'rewards/chosen': -0.9346145391464233, 'rewards/rejected': -7.900615692138672, 'rewards/accuracies': 0.9750000238418579, 'rewards/margins': 6.966001033782959, 'logps/rejected': -164.94879150390625, 'logps/chosen': -79.9521255493164, 'logits/rejected': 1.2978918552398682, 'logits/chosen': -0.7608845829963684, 'epoch': 1.13}
{'loss': 0.1016, 'grad_norm': 1.453125, 'learning_rate': 1.6571008045873305e-05, 'rewards/chosen': -1.2113415002822876, 'rewards/rejected': -7.9060516357421875, 'rewards/accuracies': 0.96875, 

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8031892776489258, 'eval_runtime': 241.5978, 'eval_samples_per_second': 1.838, 'eval_steps_per_second': 0.919, 'eval_rewards/chosen': -4.136497497558594, 'eval_rewards/rejected': -6.403458118438721, 'eval_rewards/accuracies': 0.727477490901947, 'eval_rewards/margins': 2.2669601440429688, 'eval_logps/rejected': -154.52249145507812, 'eval_logps/chosen': -115.01934051513672, 'eval_logits/rejected': 0.3165725767612457, 'eval_logits/chosen': -1.3544776439666748, 'epoch': 1.25}




{'loss': 0.0438, 'grad_norm': 3.625, 'learning_rate': 1.5781957485128947e-05, 'rewards/chosen': -1.5442534685134888, 'rewards/rejected': -8.52992057800293, 'rewards/accuracies': 0.987500011920929, 'rewards/margins': 6.985666751861572, 'logps/rejected': -171.83563232421875, 'logps/chosen': -88.0844497680664, 'logits/rejected': 0.35517340898513794, 'logits/chosen': -1.5864067077636719, 'epoch': 1.27}
{'loss': 0.0672, 'grad_norm': 9.125, 'learning_rate': 1.5679057568508683e-05, 'rewards/chosen': -1.5455121994018555, 'rewards/rejected': -8.769659042358398, 'rewards/accuracies': 0.9750000238418579, 'rewards/margins': 7.224146842956543, 'logps/rejected': -173.4645233154297, 'logps/chosen': -91.39881896972656, 'logits/rejected': -1.1188387870788574, 'logits/chosen': -2.5360589027404785, 'epoch': 1.28}
{'loss': 0.043, 'grad_norm': 2.328125, 'learning_rate': 1.557526229598824e-05, 'rewards/chosen': -1.1841191053390503, 'rewards/rejected': -8.379477500915527, 'rewards/accuracies': 0.993749976158

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8447271585464478, 'eval_runtime': 241.9125, 'eval_samples_per_second': 1.835, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -4.706053733825684, 'eval_rewards/rejected': -7.380492687225342, 'eval_rewards/accuracies': 0.7297297120094299, 'eval_rewards/margins': 2.674437999725342, 'eval_logps/rejected': -164.29283142089844, 'eval_logps/chosen': -120.71488952636719, 'eval_logits/rejected': -1.6792519092559814, 'eval_logits/chosen': -3.344982862472534, 'epoch': 1.41}




{'loss': 0.0842, 'grad_norm': 0.8671875, 'learning_rate': 1.4714659459400197e-05, 'rewards/chosen': -1.4299651384353638, 'rewards/rejected': -9.768987655639648, 'rewards/accuracies': 0.981249988079071, 'rewards/margins': 8.33902359008789, 'logps/rejected': -187.774658203125, 'logps/chosen': -86.75746154785156, 'logits/rejected': -0.9985278248786926, 'logits/chosen': -3.070699453353882, 'epoch': 1.42}
{'loss': 0.0093, 'grad_norm': 0.84765625, 'learning_rate': 1.460355845458695e-05, 'rewards/chosen': -1.245634913444519, 'rewards/rejected': -9.713510513305664, 'rewards/accuracies': 1.0, 'rewards/margins': 8.467874526977539, 'logps/rejected': -185.9691162109375, 'logps/chosen': -80.33207702636719, 'logits/rejected': -1.2590669393539429, 'logits/chosen': -2.608199119567871, 'epoch': 1.44}
{'loss': 0.035, 'grad_norm': 0.341796875, 'learning_rate': 1.4491731656246444e-05, 'rewards/chosen': -1.268381953239441, 'rewards/rejected': -9.527204513549805, 'rewards/accuracies': 0.987500011920929, 're

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.7890896201133728, 'eval_runtime': 241.9429, 'eval_samples_per_second': 1.835, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -4.738537311553955, 'eval_rewards/rejected': -7.445478916168213, 'eval_rewards/accuracies': 0.7477477192878723, 'eval_rewards/margins': 2.7069408893585205, 'eval_logps/rejected': -164.9426727294922, 'eval_logps/chosen': -121.03974151611328, 'eval_logits/rejected': -2.167969226837158, 'eval_logits/chosen': -3.828482151031494, 'epoch': 1.57}




{'loss': 0.0759, 'grad_norm': 0.126953125, 'learning_rate': 1.3573127103628666e-05, 'rewards/chosen': -1.1793973445892334, 'rewards/rejected': -9.332136154174805, 'rewards/accuracies': 0.981249988079071, 'rewards/margins': 8.152738571166992, 'logps/rejected': -182.75967407226562, 'logps/chosen': -85.06243896484375, 'logits/rejected': -2.7352347373962402, 'logits/chosen': -3.8798813819885254, 'epoch': 1.58}
{'loss': 0.0268, 'grad_norm': 11.4375, 'learning_rate': 1.345557434347042e-05, 'rewards/chosen': -1.4066097736358643, 'rewards/rejected': -9.400697708129883, 'rewards/accuracies': 0.981249988079071, 'rewards/margins': 7.994089603424072, 'logps/rejected': -179.3679656982422, 'logps/chosen': -87.85325622558594, 'logits/rejected': -2.366121530532837, 'logits/chosen': -3.6166083812713623, 'epoch': 1.6}
{'loss': 0.0324, 'grad_norm': 0.67578125, 'learning_rate': 1.3337476780087407e-05, 'rewards/chosen': -1.3526214361190796, 'rewards/rejected': -10.408535957336426, 'rewards/accuracies': 0.9

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8847374320030212, 'eval_runtime': 241.6942, 'eval_samples_per_second': 1.837, 'eval_steps_per_second': 0.919, 'eval_rewards/chosen': -5.3233323097229, 'eval_rewards/rejected': -8.077752113342285, 'eval_rewards/accuracies': 0.7184684872627258, 'eval_rewards/margins': 2.754420042037964, 'eval_logps/rejected': -171.2654266357422, 'eval_logps/chosen': -126.88768768310547, 'eval_logits/rejected': -1.7533949613571167, 'eval_logits/chosen': -3.4619107246398926, 'epoch': 1.72}




{'loss': 0.094, 'grad_norm': 1.796875, 'learning_rate': 1.2375334333084932e-05, 'rewards/chosen': -1.6035562753677368, 'rewards/rejected': -10.377058029174805, 'rewards/accuracies': 0.96875, 'rewards/margins': 8.773500442504883, 'logps/rejected': -193.09446716308594, 'logps/chosen': -86.76268768310547, 'logits/rejected': -1.6796293258666992, 'logits/chosen': -3.6540684700012207, 'epoch': 1.74}
{'loss': 0.1447, 'grad_norm': 7.1875, 'learning_rate': 1.225318073607753e-05, 'rewards/chosen': -1.5598177909851074, 'rewards/rejected': -10.410630226135254, 'rewards/accuracies': 0.9937499761581421, 'rewards/margins': 8.850812911987305, 'logps/rejected': -196.73187255859375, 'logps/chosen': -89.08562469482422, 'logits/rejected': -2.2773849964141846, 'logits/chosen': -3.9072132110595703, 'epoch': 1.75}
{'loss': 0.0107, 'grad_norm': 1.3359375, 'learning_rate': 1.2130671904307692e-05, 'rewards/chosen': -0.9228496551513672, 'rewards/rejected': -9.89411449432373, 'rewards/accuracies': 1.0, 'rewards/m

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8103374242782593, 'eval_runtime': 241.7327, 'eval_samples_per_second': 1.837, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -5.39689826965332, 'eval_rewards/rejected': -8.190281867980957, 'eval_rewards/accuracies': 0.727477490901947, 'eval_rewards/margins': 2.7933828830718994, 'eval_logps/rejected': -172.39071655273438, 'eval_logps/chosen': -127.62334442138672, 'eval_logits/rejected': -2.213864803314209, 'eval_logits/chosen': -3.8715150356292725, 'epoch': 1.88}




{'loss': 0.0646, 'grad_norm': 4.8125, 'learning_rate': 1.1140140907337437e-05, 'rewards/chosen': -1.4005401134490967, 'rewards/rejected': -10.798044204711914, 'rewards/accuracies': 0.981249988079071, 'rewards/margins': 9.397503852844238, 'logps/rejected': -198.8382568359375, 'logps/chosen': -85.65777587890625, 'logits/rejected': -1.882615327835083, 'logits/chosen': -3.461836576461792, 'epoch': 1.89}
{'loss': 0.0166, 'grad_norm': 6.09375, 'learning_rate': 1.1015309834121083e-05, 'rewards/chosen': -1.4918826818466187, 'rewards/rejected': -11.159431457519531, 'rewards/accuracies': 0.9937499761581421, 'rewards/margins': 9.667550086975098, 'logps/rejected': -197.66485595703125, 'logps/chosen': -82.3350830078125, 'logits/rejected': -2.545574188232422, 'logits/chosen': -4.18272066116333, 'epoch': 1.91}
{'loss': 0.0437, 'grad_norm': 0.53125, 'learning_rate': 1.0890318687927912e-05, 'rewards/chosen': -1.2422250509262085, 'rewards/rejected': -10.17741870880127, 'rewards/accuracies': 0.9812499880

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8374937772750854, 'eval_runtime': 241.7384, 'eval_samples_per_second': 1.837, 'eval_steps_per_second': 0.918, 'eval_rewards/chosen': -5.504368305206299, 'eval_rewards/rejected': -8.476333618164062, 'eval_rewards/accuracies': 0.7454954981803894, 'eval_rewards/margins': 2.9719650745391846, 'eval_logps/rejected': -175.25123596191406, 'eval_logps/chosen': -128.6980438232422, 'eval_logits/rejected': -2.839756965637207, 'eval_logits/chosen': -4.508242130279541, 'epoch': 2.04}




{'loss': 0.0021, 'grad_norm': 0.027099609375, 'learning_rate': 9.886995475270205e-06, 'rewards/chosen': -1.1949118375778198, 'rewards/rejected': -11.197511672973633, 'rewards/accuracies': 1.0, 'rewards/margins': 10.00260066986084, 'logps/rejected': -200.86715698242188, 'logps/chosen': -83.22762298583984, 'logits/rejected': -2.768893241882324, 'logits/chosen': -4.5776567459106445, 'epoch': 2.05}
{'loss': 0.024, 'grad_norm': 0.025146484375, 'learning_rate': 9.761452444493389e-06, 'rewards/chosen': -1.3402456045150757, 'rewards/rejected': -11.453248023986816, 'rewards/accuracies': 0.987500011920929, 'rewards/margins': 10.113001823425293, 'logps/rejected': -200.0578155517578, 'logps/chosen': -86.02039337158203, 'logits/rejected': -2.37975811958313, 'logits/chosen': -4.374926567077637, 'epoch': 2.07}
{'loss': 0.0079, 'grad_norm': 0.39453125, 'learning_rate': 9.635947022942876e-06, 'rewards/chosen': -1.4709943532943726, 'rewards/rejected': -11.062249183654785, 'rewards/accuracies': 0.9937499

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8747154474258423, 'eval_runtime': 241.3682, 'eval_samples_per_second': 1.84, 'eval_steps_per_second': 0.92, 'eval_rewards/chosen': -6.028387546539307, 'eval_rewards/rejected': -9.119328498840332, 'eval_rewards/accuracies': 0.7364864945411682, 'eval_rewards/margins': 3.0909409523010254, 'eval_logps/rejected': -181.68118286132812, 'eval_logps/chosen': -133.938232421875, 'eval_logits/rejected': -3.388364553451538, 'eval_logits/chosen': -5.05573844909668, 'epoch': 2.19}




{'loss': 0.011, 'grad_norm': 0.1005859375, 'learning_rate': 8.635629347786338e-06, 'rewards/chosen': -1.7463159561157227, 'rewards/rejected': -11.586782455444336, 'rewards/accuracies': 0.9937499761581421, 'rewards/margins': 9.84046745300293, 'logps/rejected': -201.5366973876953, 'logps/chosen': -90.95205688476562, 'logits/rejected': -2.6847786903381348, 'logits/chosen': -4.293046951293945, 'epoch': 2.21}
{'loss': 0.0019, 'grad_norm': 0.040771484375, 'learning_rate': 8.511351088173904e-06, 'rewards/chosen': -2.079502582550049, 'rewards/rejected': -12.395157814025879, 'rewards/accuracies': 1.0, 'rewards/margins': 10.315653800964355, 'logps/rejected': -211.43338012695312, 'logps/chosen': -97.79947662353516, 'logits/rejected': -3.8358092308044434, 'logits/chosen': -5.289044380187988, 'epoch': 2.22}
{'loss': 0.0067, 'grad_norm': 2.46875, 'learning_rate': 8.38730752781754e-06, 'rewards/chosen': -1.7524598836898804, 'rewards/rejected': -11.588372230529785, 'rewards/accuracies': 1.0, 'rewards/

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.8984469771385193, 'eval_runtime': 241.6385, 'eval_samples_per_second': 1.837, 'eval_steps_per_second': 0.919, 'eval_rewards/chosen': -6.081781387329102, 'eval_rewards/rejected': -9.154696464538574, 'eval_rewards/accuracies': 0.727477490901947, 'eval_rewards/margins': 3.0729153156280518, 'eval_logps/rejected': -182.03485107421875, 'eval_logps/chosen': -134.47216796875, 'eval_logits/rejected': -3.9180545806884766, 'eval_logits/chosen': -5.556198596954346, 'epoch': 2.35}




{'loss': 0.0039, 'grad_norm': 1.2421875, 'learning_rate': 7.405745819877117e-06, 'rewards/chosen': -1.3962681293487549, 'rewards/rejected': -11.849300384521484, 'rewards/accuracies': 1.0, 'rewards/margins': 10.453032493591309, 'logps/rejected': -200.9911651611328, 'logps/chosen': -88.87690734863281, 'logits/rejected': -4.332930088043213, 'logits/chosen': -6.0395612716674805, 'epoch': 2.36}
{'loss': 0.0028, 'grad_norm': 0.08837890625, 'learning_rate': 7.284689145790879e-06, 'rewards/chosen': -1.565704107284546, 'rewards/rejected': -11.247880935668945, 'rewards/accuracies': 1.0, 'rewards/margins': 9.682177543640137, 'logps/rejected': -192.40196228027344, 'logps/chosen': -88.87049865722656, 'logits/rejected': -4.715418815612793, 'logits/chosen': -6.456287384033203, 'epoch': 2.38}
{'loss': 0.0017, 'grad_norm': 0.251953125, 'learning_rate': 7.164060565550287e-06, 'rewards/chosen': -1.5877081155776978, 'rewards/rejected': -11.789549827575684, 'rewards/accuracies': 1.0, 'rewards/margins': 10.

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.9451972842216492, 'eval_runtime': 241.6582, 'eval_samples_per_second': 1.837, 'eval_steps_per_second': 0.919, 'eval_rewards/chosen': -6.297119140625, 'eval_rewards/rejected': -9.26965618133545, 'eval_rewards/accuracies': 0.7252252101898193, 'eval_rewards/margins': 2.9725375175476074, 'eval_logps/rejected': -183.18447875976562, 'eval_logps/chosen': -136.6255645751953, 'eval_logits/rejected': -3.7901370525360107, 'eval_logits/chosen': -5.431327819824219, 'epoch': 2.51}




{'loss': 0.0015, 'grad_norm': 2.4375, 'learning_rate': 6.2167099338136095e-06, 'rewards/chosen': -1.613863229751587, 'rewards/rejected': -11.89138126373291, 'rewards/accuracies': 1.0, 'rewards/margins': 10.277518272399902, 'logps/rejected': -201.74620056152344, 'logps/chosen': -89.22239685058594, 'logits/rejected': -3.54023814201355, 'logits/chosen': -5.420180320739746, 'epoch': 2.52}
{'loss': 0.0754, 'grad_norm': 0.107421875, 'learning_rate': 6.10078093437313e-06, 'rewards/chosen': -2.034450054168701, 'rewards/rejected': -12.459837913513184, 'rewards/accuracies': 0.9937499761581421, 'rewards/margins': 10.42538833618164, 'logps/rejected': -217.42745971679688, 'logps/chosen': -93.59658813476562, 'logits/rejected': -3.328134536743164, 'logits/chosen': -4.7314863204956055, 'epoch': 2.54}
{'loss': 0.0013, 'grad_norm': 0.0194091796875, 'learning_rate': 5.985466682847141e-06, 'rewards/chosen': -1.9505598545074463, 'rewards/rejected': -12.063024520874023, 'rewards/accuracies': 1.0, 'rewards/m

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.9280789494514465, 'eval_runtime': 241.3437, 'eval_samples_per_second': 1.84, 'eval_steps_per_second': 0.92, 'eval_rewards/chosen': -6.223053455352783, 'eval_rewards/rejected': -9.315306663513184, 'eval_rewards/accuracies': 0.7319819927215576, 'eval_rewards/margins': 3.0922529697418213, 'eval_logps/rejected': -183.64097595214844, 'eval_logps/chosen': -135.88490295410156, 'eval_logits/rejected': -3.878253221511841, 'eval_logits/chosen': -5.5339202880859375, 'epoch': 2.66}




{'loss': 0.0017, 'grad_norm': 0.310546875, 'learning_rate': 5.087243568272078e-06, 'rewards/chosen': -1.828168511390686, 'rewards/rejected': -12.077177047729492, 'rewards/accuracies': 1.0, 'rewards/margins': 10.249009132385254, 'logps/rejected': -205.72647094726562, 'logps/chosen': -92.82276153564453, 'logits/rejected': -3.938037395477295, 'logits/chosen': -5.904210567474365, 'epoch': 2.68}
{'loss': 0.0033, 'grad_norm': 0.275390625, 'learning_rate': 4.978267595166084e-06, 'rewards/chosen': -1.7578035593032837, 'rewards/rejected': -12.09254264831543, 'rewards/accuracies': 1.0, 'rewards/margins': 10.334739685058594, 'logps/rejected': -213.3328399658203, 'logps/chosen': -86.18280029296875, 'logits/rejected': -3.532566785812378, 'logits/chosen': -5.414917945861816, 'epoch': 2.69}
{'loss': 0.0038, 'grad_norm': 0.035400390625, 'learning_rate': 4.870083344574531e-06, 'rewards/chosen': -1.5474770069122314, 'rewards/rejected': -12.073423385620117, 'rewards/accuracies': 1.0, 'rewards/margins': 1

  0%|          | 0/222 [00:00<?, ?it/s]

{'eval_loss': 0.9329873919487, 'eval_runtime': 241.5419, 'eval_samples_per_second': 1.838, 'eval_steps_per_second': 0.919, 'eval_rewards/chosen': -6.304708003997803, 'eval_rewards/rejected': -9.430541038513184, 'eval_rewards/accuracies': 0.727477490901947, 'eval_rewards/margins': 3.12583327293396, 'eval_logps/rejected': -184.7933349609375, 'eval_logps/chosen': -136.70143127441406, 'eval_logits/rejected': -3.997708320617676, 'eval_logits/chosen': -5.640864372253418, 'epoch': 2.82}




{'loss': 0.009, 'grad_norm': 0.2490234375, 'learning_rate': 4.03513065434528e-06, 'rewards/chosen': -1.7783300876617432, 'rewards/rejected': -13.076261520385742, 'rewards/accuracies': 0.9937499761581421, 'rewards/margins': 11.297931671142578, 'logps/rejected': -223.69815063476562, 'logps/chosen': -88.68041229248047, 'logits/rejected': -4.250003814697266, 'logits/chosen': -5.50979471206665, 'epoch': 2.83}
{'loss': 0.0084, 'grad_norm': 0.142578125, 'learning_rate': 3.934823580888473e-06, 'rewards/chosen': -1.712306261062622, 'rewards/rejected': -11.985713958740234, 'rewards/accuracies': 0.9937499761581421, 'rewards/margins': 10.273408889770508, 'logps/rejected': -202.80624389648438, 'logps/chosen': -87.18635559082031, 'logits/rejected': -3.3614087104797363, 'logits/chosen': -4.8861589431762695, 'epoch': 2.85}
{'loss': 0.0045, 'grad_norm': 0.036865234375, 'learning_rate': 3.83547273853638e-06, 'rewards/chosen': -1.8169723749160767, 'rewards/rejected': -12.242486000061035, 'rewards/accurac

KeyboardInterrupt: 

## TODO

- [x] Count the tokens in the train set, max length and max prompt length
- [x] Do the training parameters from the example have sense?
- [x] Shuffle the data?
- [x] Verify that the model generates text correctly, on the prompt recovery challenge after fine-tuning the model forgot to use the eos token.