In [38]:
import torch
import sys
sys.path.append('..')
from model.utils import LMHyperParams, SmModel, ModelChoice
from synthetic_data.utils import dictl
from dataset.squad import UltraFeedbackDataModule
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft.tuners.lora.config import LoraConfig
from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig
from typing import cast
from peft.peft_model import PeftModel
import gc
from torch.amp.autocast_mode import autocast
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
model_id = "qgallouedec/tiny-LlamaForCausalLM-3"
model_id = "meta-llama/Llama-3.2-1B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_id)  # type: ignore
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # to prevent errors with FA
tokenizer.truncation_side = "left"  # to prevent cutting off last generation

In [27]:
data_module = UltraFeedbackDataModule(2, tokenizer, 1024, 1000, False)
# debugger will fail without this
data_module.num_workers = 1
data_module.setup("fit")

[32m2024-11-24 17:06:11.154[0m | [1mINFO    [0m | [36mdataset.squad[0m:[36msetup[0m:[36m220[0m - [1mLoading dataset for stage fit[0m
[32m2024-11-24 17:06:12.462[0m | [1mINFO    [0m | [36mdataset.squad[0m:[36msetup[0m:[36m224[0m - [1mLoaded dataset with 60917 samples[0m
[32m2024-11-24 17:06:12.568[0m | [1mINFO    [0m | [36mdataset.squad[0m:[36msetup[0m:[36m232[0m - [1mProcessing dataset for stage fit, workers: 1, cache dir dataset_caches/ultrafeedback[0m
Map: 100%|██████████| 900/900 [00:00<00:00, 5918.28 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 4362.43 examples/s]


In [35]:
gc.collect()
torch.cuda.empty_cache()

In [36]:
data_module.train_dataset

Dataset({
    features: ['source', 'prompt', 'chosen', 'chosen-rating', 'chosen-model', 'rejected', 'rejected-rating', 'rejected-model'],
    num_rows: 900
})

In [31]:
# max_prompt_length is the maximum length of the prompt and the max_length is the maximum length of the prompt + chosen or rejected response
prompt_length = 1024
max_seq_length = 1512

peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

args = DPOConfig(
    output_dir="../outputs",
    num_train_epochs=1,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    learning_rate=5e-5,
    max_grad_norm=0.3,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=25,
    save_steps=500,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=700,
    bf16=True,
    tf32=True,
    push_to_hub=False,
    report_to="tensorboard",
    # debugger will fail without this
    dataloader_num_workers=0,
    dataset_num_proc=1,
    max_length=max_seq_length,
    max_prompt_length=prompt_length,
    precompute_ref_log_probs=True,
    dataloader_pin_memory=True,
    beta=0.1,
    loss_type="sigmoid",
)


trainer = DPOTrainer(
    model,
    ref_model=None,  # set to none since we use peft
    peft_config=peft_config,
    args=args,
    train_dataset=data_module.train_dataset,
    eval_dataset=data_module.val_dataset,
    tokenizer=tokenizer,  # type: ignore
)

Extracting prompt from train dataset: 100%|██████████| 900/900 [00:00<00:00, 2098.57 examples/s]
Applying chat template to train dataset: 100%|██████████| 900/900 [00:00<00:00, 2347.77 examples/s]
Extracting prompt from eval dataset: 100%|██████████| 100/100 [00:00<00:00, 1998.40 examples/s]
Applying chat template to eval dataset: 100%|██████████| 100/100 [00:00<00:00, 2005.45 examples/s]
Tokenizing train dataset: 100%|██████████| 900/900 [00:02<00:00, 435.83 examples/s]
Tokenizing eval dataset: 100%|██████████| 100/100 [00:00<00:00, 414.63 examples/s]


In [33]:
trainer._peft_has_been_casted_to_bf16

True

In [None]:
from tqdm import tqdm
from trl.trainer.dpo_trainer import PreferenceCollator


def get_sample_wise_metrics(model, batch):
    metrics = {}

    model_output = trainer.concatenated_forward(model, batch)

    # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
    if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
        ref_chosen_logps = batch["ref_chosen_logps"]
        ref_rejected_logps = batch["ref_rejected_logps"]
    else:
        ref_chosen_logps, ref_rejected_logps = trainer.compute_ref_log_probs(batch)

    losses, chosen_rewards, rejected_rewards = trainer.dpo_loss(
        model_output["chosen_logps"],
        model_output["rejected_logps"],
        ref_chosen_logps,
        ref_rejected_logps,
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    reward_margins = chosen_rewards - rejected_rewards

    metrics = {
        "loss": losses.tolist(),
        "reward_accuracy": reward_accuracies.tolist(),
        "reward_margin": reward_margins.tolist(),
    }

    for k in [
        "chosen_logps",
        "rejected_logps",
    ]:
        metrics[k] = model_output[k].tolist()
    return metrics


outputs = []

# call this to precompute logprobs
trainer.get_train_dataloader()
collator = PreferenceCollator(tokenizer.pad_token_id, "pt")  # type: ignore
batch_size = 1
with torch.no_grad(), autocast("cuda"):
    for batch in tqdm(
        trainer.train_dataset.iter(batch_size=batch_size),
        desc="Precomputing logprobs",
        total=len(trainer.train_dataset) // batch_size,
    ):
        sample_collated = collator(dictl(batch))
        metrics = get_sample_wise_metrics(model, sample_collated)
        for i in range(batch_size):
            out_sample = {
                "prompt": batch["prompt"][i],
                "chosen": batch["chosen"][i],
                "rejected": batch["rejected"][i],
            }
            for k, v in metrics.items():
                if isinstance(v, list):
                    out_sample[k] = v[i]
                else:
                    out_sample[k] = v
            outputs.append(out_sample)

Precomputing logprobs:   0%|          | 0/900 [00:00<?, ?it/s]

Precomputing logprobs:   2%|▏         | 16/900 [00:15<17:27,  1.19s/it]

In [45]:
outputs[0]['metrics']

{'loss': [0.693359375,
  0.693359375,
  0.693359375,
  0.693359375,
  0.693359375,
  0.693359375,
  0.693359375,
  0.693359375],
 'reward_accuracy': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 'reward_margin': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 'chosen_logps': [-303.25,
  -659.0,
  -253.125,
  -448.0,
  -324.0,
  -278.25,
  -1229.0,
  -738.5],
 'rejected_logps': [-215.75,
  -521.0,
  -373.75,
  -1092.0,
  -176.5,
  -614.5,
  -357.75,
  -911.0],
 'mean_chosen_logits': 1.333984375,
 'mean_rejected_logits': 1.4111328125}

In [None]:
outputs[0]['metrics']

In [None]:
# plot distributiuon of losses
import matplotlib.pyplot as plt
import numpy as np

for metric in outputs[0]['metrics']:
    values = [x['metrics'][metric] for x in outputs]
    plt.hist(values, bins=50)
    plt.title(f'{metric} distribution')
    plt.show()


In [None]:
trainer.train()