<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/training/colab/DPO_fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [None]:
# Mount Google Drive inside the Colab VM so data and checkpoints persist across sessions.
from google.colab import drive
drive.mount('/content/drive')

# Work from the DPO project folder; later cells use paths relative to this directory.
%cd drive/MyDrive/LLMs/Fine-tuning/DPO

Mounted at /content/drive
/content/drive/MyDrive/LLMs/Fine-tuning/DPO


In [None]:
!pip install peft==0.4.0 -qqq
!pip install bitsandbytes==0.41.1 -qqq
!pip install safetensors>=0.3.1 -qqq
#!pip install -U trl
!pip install wandb -qqq
!pip install tokenizers>=0.13.3 -qqq
!pip install -U transformers -qqq
!pip install accelerate==0.21.0 -qqq
!pip install git+https://github.com/huggingface/trl -qqq

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for trl (setup.py) ... [?25l[?25hdone


In [None]:
import os
#from dataclasses import dataclass, field
#from typing import Optional

import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets
from datasets import Dataset, load_dataset
import transformers
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollator,
                          PreTrainedModel,
                          PreTrainedTokenizerBase,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          BitsAndBytesConfig)

from transformers.trainer_callback import TrainerCallback

import gc

import os
from google.colab import runtime
import pandas as pd

import accelerate
import bitsandbytes as bnb
import wandb
from peft import (LoraConfig,
                  get_peft_model,
                  prepare_model_for_kbit_training,
                  PeftModel,
                  PeftConfig)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datetime import datetime
from huggingface_hub import login

from peft.tuners.lora import LoraLayer

from tqdm import tqdm

import trl
from trl import DPOTrainer
from trl.models import create_reference_model
from trl.import_utils import is_peft_available
from trl.trainer.dpo_trainer import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length

from huggingface_hub import login
from random import sample

In [None]:
# Authenticate with the Hugging Face Hub; with no argument this prompts interactively for a token.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Definitions

In [None]:
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Return the mean number of characters per token over a prefix of the dataset.

    At most `nb_examples` rows are read from `dataset`. Each row is rendered with
    `prepare_sample_text` and tokenized with `tokenizer`; the result is the total
    character count divided by the total token count.
    """
    char_count = 0
    token_count = 0
    rows = iter(dataset)
    # zip against range() caps the number of rows consumed at nb_examples.
    for _, row in tqdm(zip(range(nb_examples), rows), total=nb_examples):
        rendered = prepare_sample_text(row)
        char_count += len(rendered)
        # Fast tokenizers expose the token list on the encoding object;
        # slow tokenizers need an explicit tokenize() call.
        if tokenizer.is_fast:
            n_tokens = len(tokenizer(rendered).tokens())
        else:
            n_tokens = len(tokenizer.tokenize(rendered))
        token_count += n_tokens

    return char_count / token_count

def prepare_sample_text(example):
    """Render one row as a '### Human: ... ### Assistant: ...' string.

    Reads the 'question' and 'response_j' fields of `example`.
    """
    question = example['question']
    answer = example['response_j']
    return f"### Human: {question}\n ### Assistant: {answer}"

def formatting_prompts_func(example):
    """Build one formatted prompt per row of a batched example.

    `example` holds parallel sequences under 'question' and 'answer'; the result
    is a list with one '### Human / ### Assistant' string per question.
    """
    questions = example['question']
    return [
        f"### Human: {questions[idx]}\n ### Assistant: {example['answer'][idx]}"
        for idx in range(len(questions))
    ]


def find_all_linear_names(model):
    """Return the leaf names of every `bnb.nn.Linear4bit` submodule of `model`.

    The leaf name is the last dotted component of the module's qualified name.
    'lm_head' is excluded from the result.
    """
    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, bnb.nn.Linear4bit)
    }

    # needed for 16-bit
    leaf_names.discard("lm_head")
    return list(leaf_names)


def create_peft_model(model,
                      r=64,
                      lora_alpha=16,
                      lora_dropout=0.1,
                      bias='none',
                      task_type='CAUSAL_LM',
                      gradient_checkpointing=True,
                      bf16=True):
    """Wrap a quantized model with LoRA adapters and adjust module dtypes.

    Steps:
      1. Run `prepare_model_for_kbit_training` (and enable gradient
         checkpointing when requested).
      2. Target every module returned by `find_all_linear_names` with a
         `LoraConfig` built from `r`, `lora_alpha`, `lora_dropout`, `bias`
         and `task_type`.
      3. Walk the resulting modules and cast them: LoRA layers to bfloat16
         (when `bf16`), modules with "norm" in their name to float32, and
         float32 `lm_head` / `embed_tokens` weights to bfloat16 (when `bf16`).

    Prints the trainable-parameter summary and returns the PEFT-wrapped model.
    """
    # prepare int-4 model for training
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=gradient_checkpointing,
    )
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # get lora target modules
    target_modules = find_all_linear_names(model)
    print(f"Found {len(target_modules)} modules to quantize: {target_modules}")

    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias=bias,
        task_type=task_type,
    )
    model = get_peft_model(model, lora_config)

    # Per-module dtype adjustments; `.to()` on a module converts it in place.
    for layer_name, layer in model.named_modules():
        if bf16 and isinstance(layer, LoraLayer):
            layer.to(torch.bfloat16)
        if "norm" in layer_name:
            # keep normalisation layers in float32
            layer.to(torch.float32)
        is_head_or_embedding = "lm_head" in layer_name or "embed_tokens" in layer_name
        if is_head_or_embedding and hasattr(layer, "weight"):
            if bf16 and layer.weight.dtype == torch.float32:
                layer.to(torch.bfloat16)

    model.print_trainable_parameters()
    return model

class PeftSavingCallback(TrainerCallback):
    """Trainer callback that re-saves the model into each checkpoint folder
    via `save_pretrained` and removes any `pytorch_model.bin` found there."""

    def on_save(self, args, state, control, **kwargs):
        """Save `kwargs["model"]` into `checkpoint-<global_step>` under the output dir."""
        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        trained_model = kwargs["model"]
        trained_model.save_pretrained(ckpt_dir)

        weights_file = "pytorch_model.bin"
        if weights_file not in os.listdir(ckpt_dir):
            return
        os.remove(os.path.join(ckpt_dir, weights_file))

# Merge SFT model

In [None]:
from getpass import getpass

# Read secrets interactively so they are never stored in the notebook.
hf_token = getpass()
# NOTE(review): wandb_token is collected but not used anywhere in this cell — confirm it is still needed.
wandb_token = getpass()

from huggingface_hub import login
from peft import PeftModel
login(hf_token)

# Base checkpoint and the LoRA adapter that will be folded into it.
model_name = 'meta-llama/Llama-2-7b-hf'
peft_model_id = 'dhmeltzer/Llama-2-7b-hf-wiki-no-group-by-length_r_64_alpha_16'

# Load the base model in bfloat16 (no quantization config is passed here).
model = AutoModelForCausalLM.from_pretrained(
        model_name,
        return_dict=True,
        torch_dtype=torch.bfloat16
    )
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Attach the adapter, then merge its weights into the base model and drop the PEFT wrapper.
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()
model = model.merge_and_unload()

# Publish the merged model; the training section loads it as SFT_model_id.
# NOTE(review): only the model is pushed here, not the tokenizer.
model.push_to_hub('dhmeltzer/Llama-2-7b-hf-wiki-no-gl-r-64-alpha-16-full')

# Form Dataset

In [None]:
def chosen_rejected(example):
    """Label the first two answers of `example` as 'chosen' and 'rejected'.

    The answer with the strictly higher score is 'chosen'; on a tie the
    second answer is 'chosen'.
    """
    scores = example['answers.score']
    texts = example['answers.text']

    first_wins = scores[0] > scores[1]
    chosen_idx, rejected_idx = (0, 1) if first_wins else (1, 0)
    return {'chosen': texts[chosen_idx], 'rejected': texts[rejected_idx]}

def format_prompt(example):
    """Build the DPO prompt for a row: the question from 'title_body' followed
    by an empty '### Assistant: ' slot for the completion."""
    return "### Human: {}\n ### Assistant: ".format(example['title_body'])


def reformat_dataset(ds,tokenizer):
    """Convert a two-answer dataset into DPO format and add a token-length column.

    Adds 'chosen' / 'rejected' columns via `chosen_rejected`, removes the raw
    'answers.score', 'answers.text' and 'title_body' columns, then stores in
    'lengths' the summed token counts of 'prompt', 'chosen' and 'rejected'.
    """
    measured_columns = ('prompt', 'chosen', 'rejected')

    def pick_pair(example):
        return chosen_rejected(example)

    def tot_length(example):
        total = 0
        for key in measured_columns:
            total += len(tokenizer(example[key])['input_ids'])
        return total

    def add_lengths(example):
        return {'lengths': tot_length(example)}

    ds = ds.map(pick_pair)
    ds = ds.remove_columns(['answers.score','answers.text','title_body'])
    ds = ds.map(add_lengths)

    return ds


def choose_random_answers(example):
    """Keep two answers of `example`, drawn at random without replacement.

    Returns replacement 'answers.score' and 'answers.text' lists of length two,
    with scores and texts kept aligned.
    """
    all_scores = example['answers.score']
    all_texts = example['answers.text']

    first, second = sample(range(len(all_scores)), 2)

    return {
        'answers.score': [all_scores[first], all_scores[second]],
        'answers.text': [all_texts[first], all_texts[second]],
    }


In [None]:
# Tokenizer is only used here to measure token lengths inside reformat_dataset.
tokenizer = AutoTokenizer.from_pretrained(
        'meta-llama/Llama-2-7b-hf'
    )

# Load the dataset from disk and keep only the columns needed to build preference pairs.
ds_RM = datasets.load_from_disk('../../ELI5_dataset/data/RM_non_toxic')
features = list(ds_RM['train'].features)
ds_RM = ds_RM.remove_columns([col for col in features if
                             col not in ['answers.score',
                                         'answers.text',
                                         'title_body']])

# Add the formatted prompt column.
ds_RM = ds_RM.map(lambda x: {'prompt':format_prompt(x)})

# Variant 1: pair the first two answers of each question.
# NOTE(review): presumably answers are sorted by descending score, making these the top two — confirm.
ds_RM_top_2 = ds_RM.map(lambda x: {'answers.score': x['answers.score'][:2],
                                   'answers.text':x['answers.text'][:2]})
ds_RM_top_2 = reformat_dataset(ds_RM_top_2,tokenizer)


# Variant 2: pair the first answer with the last one.
ds_RM_contrast = ds_RM.map(lambda x: {'answers.score':[x['answers.score'][i] for i in [0,-1]],
                                   'answers.text':[x['answers.text'][i] for i in [0,-1]]})
ds_RM_contrast = reformat_dataset(ds_RM_contrast,tokenizer)

# Variant 3: pair two randomly sampled answers (no random seed is set, so this is not reproducible).
ds_RM_random = ds_RM.map(choose_random_answers)
ds_RM_random = reformat_dataset(ds_RM_random,tokenizer)

# Persist all three variants so the training section can reload them without rebuilding.
ds_RM_top_2.save_to_disk('../data/ds_RM_top_2')
ds_RM_contrast.save_to_disk('../data/ds_RM_contrast')
ds_RM_random.save_to_disk('../data/ds_RM_random')

# Training


In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub before loading models in the cells below.
login()
#wandb.login()
# Reload the three preference datasets written by the "Form Dataset" section.
ds_RM_top_2 = datasets.load_from_disk('../data/ds_RM_top_2')
ds_RM_contrast = datasets.load_from_disk('../data/ds_RM_contrast')
ds_RM_random = datasets.load_from_disk('../data/ds_RM_random')


In [None]:
def DPO_training(model_id,
                dataset,
                hf_token,
                gradient_checkpointing=True,
                r=64,
                lora_alpha=16,
                lora_dropout=0.1,
                bias='none',
                task_type='CAUSAL_LM',
                max_prompt_length=2048,
                max_length=4096,
                epochs = 1,
                max_steps = -1,
                lr=5e-5,
                weight_decay=.1,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=32,
                gradient_accumulation_steps=8,
                optim='adamw_torch_fused',
                warmup_ratio=0.03,
                group_by_length=True,
                dataloader_num_workers=2,
                logging_steps=20,
                save_total_limit=3,
                save_strategy='steps',
                save_steps =.1,
                eval_steps=.1,
                load_best_model_at_end=True,
                project_name='DPO_training_dm',
                entity='ft-llmmm',
                torch_compile=False,
                length_column_name='lengths',
                 repo_id = None):
    """Fine-tune a 4-bit quantized causal LM with LoRA adapters using DPO.

    Args:
        model_id: Hub id or path of the model (and tokenizer) to load.
        dataset: Mapping with 'train' and 'validation' splits in DPO format
            ('prompt', 'chosen', 'rejected' columns, plus the column named by
            `length_column_name` when `group_by_length` is used).
        hf_token: Hugging Face token used to download the model/tokenizer and
            to push to the Hub.
        gradient_checkpointing: Enable gradient checkpointing (disables the
            KV cache while training).
        r, lora_alpha, lora_dropout, bias, task_type: Forwarded to
            `create_peft_model` for the LoRA configuration.
        max_prompt_length, max_length: Truncation limits passed to `DPOTrainer`.
        epochs, max_steps, lr, weight_decay, per_device_train_batch_size,
        per_device_eval_batch_size, gradient_accumulation_steps, optim,
        warmup_ratio, group_by_length, dataloader_num_workers, logging_steps,
        save_total_limit, save_strategy, save_steps, eval_steps,
        load_best_model_at_end, torch_compile, length_column_name:
            Forwarded to `TrainingArguments`. `save_strategy` is also used as
            the evaluation strategy.
        project_name, entity: Weights & Biases project and entity for the run.
        repo_id: If set, a model card is created and the model is pushed to
            this Hub repository after training.

    Side effects:
        Starts and finishes a W&B run, logs evaluation metrics before and
        after training, and saves the final model to
        `./{model_id}_DPO_r_{r}_alpha_{lora_alpha}`.
    """
    # GPUs with compute capability >= 8 support bfloat16; older ones fall back
    # to fp16. (The flags must be plain booleans: a trailing comma here would
    # silently create a one-element tuple.)
    bf16 = torch.cuda.get_device_capability()[0] >= 8
    fp16 = not bf16

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # the KV cache must be off when gradient checkpointing is used
        use_cache=not gradient_checkpointing,
        device_map="auto",
        quantization_config=bnb_config,
        use_auth_token=hf_token,
    )

    model = create_peft_model(model,
                              r=r,
                              lora_alpha=lora_alpha,
                              lora_dropout=lora_dropout,
                              bias=bias,
                              task_type=task_type,
                              gradient_checkpointing=gradient_checkpointing,
                              bf16=bf16)

    model.train()

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_auth_token=hf_token,
    )

    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    output_dir = f'./{model_id}_DPO_r_{r}_alpha_{lora_alpha}'

    # Start the W&B run explicitly so metrics can be logged to it before and
    # after training.
    run = wandb.init(project=project_name, entity=entity)

    training_args = TrainingArguments(
        logging_dir=os.path.join(output_dir, 'logs'),
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        bf16=bf16,  # Use BF16 if available
        fp16=fp16,
        learning_rate=lr,
        num_train_epochs=epochs,
        max_steps=max_steps,
        gradient_checkpointing=gradient_checkpointing,
        optim=optim,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay,
        gradient_accumulation_steps=gradient_accumulation_steps,
        group_by_length=group_by_length,
        # logging strategies
        logging_strategy="steps",
        logging_steps=logging_steps,
        save_strategy=save_strategy,
        # evaluation must follow the save schedule for load_best_model_at_end
        evaluation_strategy=save_strategy,
        save_steps=save_steps,
        eval_steps=eval_steps,
    #   log_level = 'error',
        hub_token=hf_token,
        # push_to_hub() targets this repository; None keeps the default name
        hub_model_id=repo_id,
        report_to='wandb',
        dataloader_num_workers=dataloader_num_workers,
        load_best_model_at_end=load_best_model_at_end,
        save_total_limit=save_total_limit,
        # keep the DPO columns instead of pruning them to the model signature
        remove_unused_columns=False,
        disable_tqdm=False,
        torch_compile=torch_compile,
        length_column_name=length_column_name
        #max_grad_norm=0.3
    )

    dpo_trainer = DPOTrainer(
        model,
        args=training_args,
        beta=1,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'],
        tokenizer=tokenizer,
        max_prompt_length=max_prompt_length,
        max_length=max_length,
    )

    # Baseline metrics before any DPO update.
    original_performance = dpo_trainer.evaluate()
    run.log({'initial-performance': wandb.Table(dataframe=pd.DataFrame(original_performance, index=["Performance"]))})

    dpo_trainer.train()

    final_performance = dpo_trainer.evaluate()
    run.log({'final-performance': wandb.Table(dataframe=pd.DataFrame(final_performance, index=["Performance"]))})

    dpo_trainer.save_model(output_dir)

    if repo_id:
        dpo_trainer.create_model_card(model_name=repo_id)
        dpo_trainer.push_to_hub()

    run.finish()

In [None]:
# --- Hyper-parameters for the interactive (non-function) training run ---
gradient_checkpointing=True
r=64
lora_alpha=16
lora_dropout=0.1
bias='none'
task_type='CAUSAL_LM'
max_seq_length=512
epochs = 1
max_steps = -1
lr=2e-4
weight_decay=.01
per_device_train_batch_size=1
per_device_eval_batch_size=1
gradient_accumulation_steps=1
optim='paged_adamw_32bit'
warmup_ratio=0.03
group_by_length=True
dataloader_num_workers=2
logging_steps=10
save_total_limit=3
save_strategy='steps'
save_steps =.2
eval_steps=.2
load_best_model_at_end=True
project_name='DPO_training_dm'
entity='ft-llmmm'
torch_compile=False
length_column_name='lengths'

# bf16 is passed to create_peft_model below and must be defined in this cell,
# otherwise a fresh kernel fails with NameError. GPUs with compute
# capability >= 8 support bfloat16.
bf16 = torch.cuda.get_device_capability()[0] >= 8

# Merged SFT checkpoint to continue from, and the base model whose tokenizer is reused.
SFT_model_id = 'dhmeltzer/Llama-2-7b-hf-wiki-no-gl-r-64-alpha-16-full'
base_model_id = 'meta-llama/Llama-2-7b-hf'

#SFT_model_id = 'distilgpt2'
#base_model_id = SFT_model_id

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
        SFT_model_id,
        use_cache=False
        if gradient_checkpointing
        else True,  # this is needed for gradient checkpointing
        device_map="auto",
        quantization_config=bnb_config,
        #use_auth_token=hf_token
    )

model = create_peft_model(model,
                          r=r,
                          lora_alpha=lora_alpha,
                          lora_dropout=lora_dropout,
                          bias=bias,
                          task_type=task_type,
                          gradient_checkpointing=gradient_checkpointing,
                          bf16=bf16)

model.train()

tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        #use_auth_token=hf_token
    )

# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token



In [None]:
# Directory for this run's model outputs.
output_dir = f'./SFT_wiki_no_gl_DPO/models'



# Use the "top 2" preference pairs for training and evaluation.
train_dataset = ds_RM_top_2['train']
eval_dataset = ds_RM_top_2['validation']

In [None]:
# Drop these names from the notebook namespace so the objects they reference can be
# garbage-collected. pop() with a default keeps the cell idempotent: a bare `del`
# raises NameError when the name was never bound (e.g. on a fresh kernel, where
# `dpo_trainer` only exists inside DPO_training) or when the cell is run twice.
for _stale_name in ('DataCollator', 'dpo_trainer'):
    globals().pop(_stale_name, None)

In [None]:
# Run Python garbage collection first so unreferenced objects are released,
# then ask PyTorch to free its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()