<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/training/david/DPO/DPO_fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [None]:
from google.colab import drive
drive.mount('/content/drive')

%cd drive/MyDrive/LLMs/Fine-tuning/DPO

Mounted at /content/drive
/content/drive/MyDrive/LLMs/Fine-tuning/DPO


In [None]:
!pip install peft==0.5.0 -qqq
!pip install bitsandbytes==0.41.1 -qqq
!pip install safetensors>=0.3.1 -qqq
#!pip install -U trl
!pip install wandb -qqq
!pip install tokenizers>=0.13.3 -qqq
!pip install -U transformers -qqq
!pip install accelerate==0.21.0 -qqq
!pip install git+https://github.com/huggingface/trl -qqq

!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
!pip install ninja packaging
!pip install flash-attn --no-build-isolation

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.6/85.6 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.2/251.2 kB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m91.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
#from dataclasses import dataclass, field
#from typing import Optional

import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets
from datasets import Dataset, load_dataset
import transformers
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollator,
                          PreTrainedModel,
                          PreTrainedTokenizerBase,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          BitsAndBytesConfig)

from transformers.trainer_callback import TrainerCallback

import gc

import os
from google.colab import runtime
import pandas as pd

import accelerate
import bitsandbytes as bnb
import wandb
from peft import (LoraConfig,
                  get_peft_model,
                  prepare_model_for_kbit_training,
                  PeftModel,
                  PeftConfig)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datetime import datetime
from huggingface_hub import login

from peft.tuners.lora import LoraLayer

from tqdm import tqdm

import trl
from trl import DPOTrainer
from trl.models import create_reference_model
from trl.import_utils import is_peft_available
from trl.trainer.dpo_trainer import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length

from huggingface_hub import login
from random import sample

In [None]:
from trl.trainer.dpo_trainer import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length


# Definitions

In [None]:
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    At most `nb_examples` examples are read from `dataset`; each one is
    rendered with `prepare_sample_text` and tokenized with `tokenizer`.
    Returns total characters divided by total tokens.
    """
    char_count = 0
    token_count = 0
    # zip() with a range caps the number of examples that are consumed.
    capped_examples = zip(range(nb_examples), iter(dataset))
    for _, example in tqdm(capped_examples, total=nb_examples):
        text = prepare_sample_text(example)
        char_count += len(text)
        # Fast tokenizers expose the token list on the encoding; slow ones need tokenize().
        token_count += (
            len(tokenizer(text).tokens())
            if tokenizer.is_fast
            else len(tokenizer.tokenize(text))
        )

    return char_count / token_count

def prepare_sample_text(example):
    """Render one example as a '### Human / ### Assistant' string using its
    'question' and 'response_j' fields."""
    question = example['question']
    response = example['response_j']
    return f"### Human: {question}\n ### Assistant: {response}"

def formatting_prompts_func(example):
    """Format a batch (dict of lists with 'question' and 'answer') into a list
    of '### Human / ### Assistant' strings, one per question."""
    return [
        f"### Human: {question}\n ### Assistant: {example['answer'][i]}"
        for i, question in enumerate(example['question'])
    ]


def find_all_linear_names(model):
    """Return the distinct leaf names of all `bnb.nn.Linear4bit` submodules of
    `model`, excluding "lm_head"."""
    # The last dotted component is the layer's own name (for an undotted name
    # it is the name itself).
    leaf_names = {
        name.split(".")[-1]
        for name, module in model.named_modules()
        if isinstance(module, bnb.nn.Linear4bit)
    }

    leaf_names.discard("lm_head")  # needed for 16-bit
    return list(leaf_names)


def create_peft_model(model,
                      r=64,
                      lora_alpha=16,
                      lora_dropout=0.1,
                      bias='none',
                      task_type='CAUSAL_LM',
                      gradient_checkpointing=True,
                      bf16=True):
    """Prepare `model` for k-bit training and wrap it with LoRA adapters.

    The LoRA target modules are the names returned by `find_all_linear_names`.
    `r`, `lora_alpha`, `lora_dropout`, `bias` and `task_type` are forwarded to
    `LoraConfig`. After wrapping, module dtypes are adjusted (see the loop
    below) and the trainable-parameter summary is printed.

    Returns the PEFT-wrapped model.
    """
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Every 4-bit linear layer (except lm_head) becomes a LoRA target.
    target_modules = find_all_linear_names(model)
    print(f"Found {len(target_modules)} modules to quantize: {target_modules}")

    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias=bias,
        task_type=task_type,
    )
    model = get_peft_model(model, lora_config)

    # Dtype fix-ups, applied in this order for each submodule:
    #   1. LoRA layers -> bfloat16 (when bf16 is set)
    #   2. anything with "norm" in its name -> float32
    #   3. float32 lm_head / embed_tokens weights -> bfloat16 (when bf16 is set)
    for name, module in model.named_modules():
        if bf16 and isinstance(module, LoraLayer):
            module = module.to(torch.bfloat16)
        if "norm" in name:
            module = module.to(torch.float32)
        is_head_or_embedding = "lm_head" in name or "embed_tokens" in name
        if is_head_or_embedding and hasattr(module, "weight"):
            if bf16 and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)

    model.print_trainable_parameters()
    return model

class PeftSavingCallback(TrainerCallback):
    """Trainer callback that, on every save, writes the model into the
    step's checkpoint folder via `save_pretrained` and then removes any
    `pytorch_model.bin` found there."""

    def on_save(self, args, state, control, **kwargs):
        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(ckpt_dir)

        full_weights_path = os.path.join(ckpt_dir, "pytorch_model.bin")
        if "pytorch_model.bin" in os.listdir(ckpt_dir):
            os.remove(full_weights_path)

# Merge SFT model

In [None]:
from getpass import getpass
from huggingface_hub import login

# Read the Hugging Face token without echoing it, then authenticate.
hf_token = getpass()
login(hf_token)

··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
from getpass import getpass
from huggingface_hub import login
from peft import PeftModel

# Two hidden prompts, in this order: Hugging Face token, then Weights & Biases token.
hf_token = getpass()
wandb_token = getpass()

login(hf_token)



··········
··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
base_model = 'meta-llama/Llama-2-7b-hf'

adapter_models = [
    'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16',
    'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16',
    'dhmeltzer/llama-7b-SFT_ds_eli5_1024_r_64_alpha_16'
]

tokenizer = AutoTokenizer.from_pretrained(base_model)

for adapter_model in adapter_models:
    # Reload the pristine base weights for every adapter. merge_and_unload()
    # returns the base model with the LoRA deltas folded into its weights, so
    # reusing that model for the next adapter would stack the adapters on top
    # of each other instead of producing one independent merged model each.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        return_dict=True,
        torch_dtype=torch.bfloat16
    )

    model = PeftModel.from_pretrained(model, adapter_model)
    model.eval()
    model = model.merge_and_unload()

    # Publish the merged weights and the matching tokenizer under "<adapter>_merged".
    model.push_to_hub(f'{adapter_model}_merged')
    tokenizer.push_to_hub(f'{adapter_model}_merged')

In [None]:
# Hub repo ids of the merged SFT checkpoints (each adapter repo id with a
# "_merged" suffix).
model_ids = [
    'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged',
    'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16_merged',
    'dhmeltzer/llama-7b-SFT_ds_eli5_1024_r_64_alpha_16_merged'
]

In [None]:
# Install the Hub client and authenticate interactively: login() is called
# without a token, so it asks for one (rendered as a widget in the output).
!pip install huggingface_hub
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Drop the reference to the loaded model, run the garbage collector, and
# release cached CUDA memory before another model is loaded.
del model
gc.enable()
gc.collect()
torch.cuda.empty_cache()

In [None]:
# NOTE(review): `model_id` is first assigned by the `for model_id in ...` loop
# in the following cell, so this inspection cell raises NameError on a fresh
# top-to-bottom run.
model_id

'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16_merged'

In [None]:
# Reload each merged checkpoint from the Hub and push it back with
# safetensors serialization.
# NOTE(review): `model_ids[1:]` skips the first repo id — confirm this is
# intentional (e.g. the first one was already uploaded).
for model_id in model_ids[1:]:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )

    model.push_to_hub(model_id,safe_serialization=True)

    # Release the reference before the next checkpoint is loaded.
    del model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

In [None]:
# NOTE(review): the preceding loop ends each iteration with `del model`, so
# `model` is undefined here on a top-to-bottom run; this looks like a leftover
# manual re-push cell.
model.push_to_hub(model_id,safe_serialization=True)

'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'

In [None]:
# NOTE(review): relies on `model` and `model_id` left over from earlier cells;
# the upload loop above deletes `model`, so this fails on a top-to-bottom run.
model.push_to_hub(model_id)

# Form Dataset

## Definitions

In [None]:
def chosen_rejected(example):
    """Label the first two answers of `example` as chosen/rejected by score.

    The first answer is chosen only when its score is strictly greater than
    the second's; otherwise (including ties) the second answer is chosen.
    """
    first_score, second_score = example['answers.score'][0], example['answers.score'][1]
    texts = example['answers.text']

    winner, loser = (0, 1) if first_score > second_score else (1, 0)
    return {'chosen': texts[winner], 'rejected': texts[loser]}

def format_prompt(example):
    """Build the prompt string from the example's 'title_body' field, ending
    with an open '### Assistant: ' turn."""
    question = example['title_body']
    return f"### Human: {question}\n ### Assistant: "

def reformat_dataset(ds,tokenizer):
    """Turn a two-answer dataset into chosen/rejected form.

    Adds 'chosen' and 'rejected' via `chosen_rejected`, drops the
    'answers.score', 'answers.text' and 'title_body' columns, and adds a
    'lengths' column: token count of the prompt plus token count of the
    longer of the two answers.
    """
    ds = ds.map(chosen_rejected)
    ds = ds.remove_columns(['answers.score','answers.text','title_body'])

    def n_tokens(text):
        return len(tokenizer(text)['input_ids'])

    def total_length(example):
        longest_answer = max(n_tokens(example[key]) for key in ['chosen','rejected'])
        return longest_answer + n_tokens(example['prompt'])

    return ds.map(lambda x: {'lengths': total_length(x)})


def choose_random_answers(example):
    """Pick two distinct answers of `example` at random and return their
    scores and texts under the original 'answers.score' / 'answers.text' keys
    (scores and texts stay aligned)."""
    scores, answers = example['answers.score'], example['answers.text']

    # One draw of two distinct positions is shared by both columns.
    first, second = sample(range(len(scores)), 2)

    return {'answers.score': [scores[first], scores[second]],
            'answers.text': [answers[first], answers[second]]}


In [None]:
# Prompt (without echo) for the Hugging Face and Weights & Biases tokens.
from getpass import getpass
hf_token= getpass('input HF token')
wandb_token = getpass('input wandb token')

input HF token··········
input wandb token··········


In [None]:
# Load the reward-modelling dataset from disk and keep only the answer scores,
# answer texts and question text.
ds_RM = datasets.load_from_disk('../../ELI5_dataset/data/RM_non_toxic')
features = list(ds_RM['train'].features)
ds_RM = ds_RM.remove_columns([col for col in features if
                             col not in ['answers.score',
                                         'answers.text',
                                         'title_body']])

# Add a 'prompt' column built from 'title_body'.
ds_RM = ds_RM.map(lambda x: {'prompt':format_prompt(x)})

## Pairs of answers

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
        'meta-llama/Llama-2-7b-hf',
        token = hf_token
    )

# Variant 1: keep the first two answers of every question.
# NOTE(review): the "top_2"/"contrast" names presume the answers are stored in
# descending score order — confirm against the dataset.
ds_RM_top_2 = ds_RM.map(lambda x: {'answers.score': x['answers.score'][:2],
                                   'answers.text':x['answers.text'][:2]})
ds_RM_top_2 = reformat_dataset(ds_RM_top_2,tokenizer)


# Variant 2: pair the first answer with the last one.
ds_RM_contrast = ds_RM.map(lambda x: {'answers.score':[x['answers.score'][i] for i in [0,-1]],
                                   'answers.text':[x['answers.text'][i] for i in [0,-1]]})
ds_RM_contrast = reformat_dataset(ds_RM_contrast,tokenizer)

# Variant 3: pair two answers drawn at random.
ds_RM_random = ds_RM.map(choose_random_answers)
ds_RM_random = reformat_dataset(ds_RM_random,tokenizer)

# Persist all three variants.
ds_RM_top_2.save_to_disk('../data/ds_RM_top_2')
ds_RM_contrast.save_to_disk('../data/ds_RM_contrast')
ds_RM_random.save_to_disk('../data/ds_RM_random')

In [None]:
# Split questions by how many answers they have. The count is taken from the
# length of 'answers.score' because the loading cell removes every column
# except 'answers.score', 'answers.text' and 'title_body', so a 'num_answers'
# column is no longer available on `ds_RM`.
ds_RM_eq_2 = ds_RM.filter(lambda x: len(x['answers.score'])==2)
ds_RM_ge_2 = ds_RM.filter(lambda x: len(x['answers.score'])>2)

In [None]:
# For every split, print the fraction of questions that have exactly two answers.
for key in ds_RM_eq_2:
    print(len(ds_RM_eq_2[key])/len(ds_RM[key]))

In [None]:
import matplotlib.pyplot as plt

# Histogram of every individual answer score in the training split
# (log-scaled counts), labelled so the figure can be read on its own.
all_train_scores = [score for l in ds_RM['train']['answers.score'] for score in l]

fig, ax = plt.subplots()
ax.hist(all_train_scores, bins=200)
ax.set_yscale('log')
ax.set_title('Answer scores (train split)')
ax.set_xlabel('answer score')
ax.set_ylabel('count (log scale)')
plt.show()

In [None]:
ds_RM_top_2 = datasets.load_from_disk('../data/ds_RM_top_2')
ds_RM_contrast = datasets.load_from_disk('../data/ds_RM_contrast')
ds_RM_random = datasets.load_from_disk('../data/ds_RM_random')

ds_RM_top_2_filt, ds_RM_contrast_filt, ds_RM_random_filt = {}, {}, {}


def _fits_within(limit):
    """Build a predicate keeping examples whose prompt+chosen and
    prompt+rejected string lengths (characters, not tokens) are both <= limit."""
    def keep(x):
        prompt_chars = len(x["prompt"])
        return (prompt_chars + len(x["chosen"]) <= limit
                and prompt_chars + len(x["rejected"]) <= limit)
    return keep


# The same length filter is applied to all three pair datasets for each limit.
for max_length in [1024,2048,4096]:
    keep = _fits_within(max_length)
    ds_RM_top_2_filt[max_length] = ds_RM_top_2.filter(keep)
    ds_RM_contrast_filt[max_length] = ds_RM_contrast.filter(keep)
    ds_RM_random_filt[max_length] = ds_RM_random.filter(keep)

ds_dict = {
    'top_2': ds_RM_top_2_filt,
    'contrast': ds_RM_contrast_filt,
    'random': ds_RM_random_filt,
}

# Only the 1024 variant of each dataset is written to disk.
# NOTE(review): inputs are loaded from '../data/' but outputs are written to
# './data/' — confirm the differing directory is intended.
for key in ds_dict:
    ds_dict[key][1024].save_to_disk(f'./data/ds_RM_{key}_1024')

## Weighted Answers

In [None]:
from itertools import combinations
import random
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass

@dataclass
class DPODataCollatorWithPadding_with_margin:
    r"""
    DPO DataCollator class that pads the inputs to the maximum length of the batch
    and carries the per-pair `score_accepted`, `score_rejected` and `weight`
    values through to the collated batch as tensors.

    Each input feature must provide the keys `prompt`, `chosen`, `rejected`,
    `score_accepted`, `score_rejected` and `weight`.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        max_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the sequence to be processed. It is compared against
            token counts during tokenization, so it has to be set before the collator is called.
        max_prompt_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the prompt to be processed.
        label_pad_token_id (`int`, defaults to -100):
            The label used for masking.
        padding_value (`int`, defaults to 0):
            The value used for padding.
        truncation_mode: (`str`, defaults to "keep_end"):
            The truncation mode to use when truncating the prompt.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    label_pad_token_id: int = -100
    padding_value: int = 0
    truncation_mode: str = "keep_end"

    def tokenize_batch_element(
        self,
        prompt: str,
        chosen: str,
        rejected: str,
        score_accepted: int,
        score_rejected: int,
        weight: float
    ) -> Dict:
        """Tokenize a single batch element.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
            in case the prompt + chosen or prompt + rejected responses is/are too long. First
            we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
            the sum of the length of the prompt and the chosen/rejected response, with
            label_pad_token_id  for the prompt tokens.

        `score_accepted`, `score_rejected` and `weight` are copied into the returned
            dict unchanged.
        """
        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

        assert self.tokenizer.eos_token_id not in prompt_tokens["input_ids"], f"Prompt contains EOS token: {prompt}"
        assert (
            self.tokenizer.eos_token_id not in chosen_tokens["input_ids"]
        ), f"Chosen response contains EOS token: {chosen}"
        assert (
            self.tokenizer.eos_token_id not in rejected_tokens["input_ids"]
        ), f"Rejected response contains EOS token: {rejected}"

        # Terminate both responses with EOS (attended to like any other token).
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)

        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
            elif self.truncation_mode == "keep_end":
                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()}

        # Create labels: a copy of the input ids with the prompt positions masked out.
        chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
        rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )

        batch = {}
        batch["prompt"] = prompt
        batch["chosen"] = prompt + chosen
        batch["rejected"] = prompt + rejected
        batch["chosen_response_only"] = chosen
        batch["rejected_response_only"] = rejected
        batch['score_accepted'] = score_accepted
        batch['score_rejected'] = score_rejected
        batch['weight'] = weight

        # Flatten the token dicts into "<part>_<field>" keys, e.g. "chosen_input_ids".
        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens

        return batch

    def collate(self, batch):
        """Pad token fields to a common length and stack scores/weights into tensors.

        Prompt fields are padded on the left, all other token fields on the
        right. Keys starting with "score" or "weight" become 1-D tensors; every
        remaining key is returned as a plain list.
        """
        # first, pad everything to the same length
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
                # adapted from https://stackoverflow.com/questions/73256206
                if "prompt" in k:
                    # Reverse so that right-padding followed by a flip yields left-padding.
                    to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                else:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                if k.endswith("_input_ids"):
                    padding_value = self.tokenizer.pad_token_id
                elif k.endswith("_labels"):
                    padding_value = self.label_pad_token_id
                elif k.endswith("_attention_mask"):
                    padding_value = self.padding_value
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")

                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                # for the prompt, flip back so padding is on left side
                if "prompt" in k:
                    padded_batch[k] = padded_batch[k].flip(dims=[1])
            elif k.startswith("score") or k.startswith('weight'):
                padded_batch[k] = torch.tensor([ex[k] for ex in batch])
            else:
                padded_batch[k] = [ex[k] for ex in batch]

        return padded_batch

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Tokenize every feature and return the collated, padded batch."""
        # No per-element debug printing here: dumping every tokenized element
        # (and the growing batch) floods the output during training, and
        # collating once instead of twice avoids doing the padding work twice.
        tokenized_batch = [
            self.tokenize_batch_element(
                feature["prompt"],
                feature["chosen"],
                feature["rejected"],
                feature['score_accepted'],
                feature['score_rejected'],
                feature['weight'],
            )
            for feature in features
        ]

        # return collated batch
        return self.collate(tokenized_batch)

class DPOTrainer_with_margins(DPOTrainer):
    """
    This class implements two modifications of the DPOTrainer:
        1) Can now reweight multiple pairs of responses for a given prompt to avoid overfitting.
           That is, if a prompt has 10 chosen/rejected pairs, we can reweight these samples to avoid overfitting to this prompt.
        2) Can introduce a margin "m" into the DPO loss function L = - log(sigmoid(beta * (r_chosen - r_rejected) - alpha * m)).
           Size of "m" determines how much more r_chosen is preferred over r_rejected.
           We will take m = log(score_chosen - score_rejected).

    With the exception of the new hyperparameters `include_margin` and `alpha`, all params below are copied from
    https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py.

    Batches must contain the keys `score_accepted`, `score_rejected` and `weight`
    (as produced by `DPODataCollatorWithPadding_with_margin`).

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        beta (`float`, defaults to 0.1):
            The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
        args (`transformers.TrainingArguments`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, `DPODataCollatorWithPadding_with_margin` will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, defaults to `0`):
            The padding value. This argument is required if you want to use the default data collator.
        truncation_mode (`str`, defaults to `keep_end`):
            The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        tokenizer (`transformers.PreTrainedTokenizerBase`):
            The tokenizer to use for training. This argument is required if you want to use the default data collator.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        max_length (`int`, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        max_prompt_length (`int`, defaults to `None`):
            The maximum length of the prompt. This argument is required if you want to use the default data collator.
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        disable_dropout (`bool`, defaults to `True`):
            Whether or not to disable dropouts in `model` and `ref_model`.
        include_margin (`bool`, defaults to `True`):
            Set to true to include margin term in loss function.
        alpha (`float`, defaults to `1`.):
            Coefficient controlling strength of effect of margin on loss.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        beta: float = 0.1,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        truncation_mode: str = "keep_end",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        max_prompt_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
        disable_dropout: bool = True,
        include_margin: bool = True,
        alpha: float = 1.,
    ):
        # Default to the collator that also forwards scores and weights.
        if data_collator is None:
            data_collator = DPODataCollatorWithPadding_with_margin(
                tokenizer,
                max_length=max_length,
                max_prompt_length=max_prompt_length,
                label_pad_token_id=label_pad_token_id,
                padding_value=padding_value,
                truncation_mode=truncation_mode,
            )

        super().__init__(
            model,
            ref_model,
            beta,
            args,
            data_collator,
            label_pad_token_id,
            padding_value,
            truncation_mode,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
            max_length,
            max_prompt_length,
            peft_config,
            disable_dropout
            )

        self.alpha = alpha
        # Stored so that dpo_loss can honour it (previously accepted but ignored).
        self.include_margin = include_margin

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        scores_accepted: torch.FloatTensor,
        scores_rejected: torch.FloatTensor,
        weights: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the weighted, margin-adjusted DPO loss for a batch of policy and reference model log probabilities.

        The strength of the margin is `self.alpha`, the DPO temperature is
        `self.beta`, and the margin term is applied only when
        `self.include_margin` is true.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            scores_accepted: Scores of the chosen responses, one per example.
            scores_rejected: Scores of the rejected responses, one per example.
            weights: Per-example multipliers applied to the losses.
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        if reference_free:
            ref_logratios = 0

        logits = pi_logratios - ref_logratios

        scaled_logits = self.beta * logits
        if self.include_margin:
            # Build the score gap on the same device as the logits instead of a
            # hard-coded 'cuda'; as_tensor also avoids re-wrapping inputs that
            # are already tensors.
            score_gap = (
                torch.as_tensor(scores_accepted, dtype=torch.float32, device=logits.device)
                - torch.as_tensor(scores_rejected, dtype=torch.float32, device=logits.device)
            )
            # NOTE(review): a gap of 0 gives log(0) = -inf (loss becomes 0 for
            # that pair) and a negative gap gives NaN; the dataset preparation
            # is expected to order each pair so that the gap is non-negative.
            margin = torch.log(score_gap)
            scaled_logits = scaled_logits - self.alpha * margin

        losses = -F.logsigmoid(scaled_logits)

        # Per-example reweighting (out-of-place, on the losses' device and dtype).
        sample_weights = torch.as_tensor(weights, dtype=losses.dtype, device=losses.device)
        losses = losses * sample_weights

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return losses, chosen_rewards, rejected_rewards

    def get_batch_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
        ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test.

        Reads `score_accepted`, `score_rejected` and `weight` from `batch` and
        passes them to `dpo_loss`. Returns `(mean loss, metrics dict)`; metric
        names are prefixed with "eval_" when `train_eval == "eval"`.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        with torch.no_grad():
            if self.ref_model is None:
                # Without a separate reference model, run the policy with its
                # adapter disabled to obtain the reference log-probabilities.
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            batch['score_accepted'],
            batch['score_rejected'],
            batch['weight']
        )

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().numpy().mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()

        return losses.mean(), metrics

In [None]:
def pair_answers(x, max_pairs=10):
    """Build every unordered pair of answers for a single question.

    Parameters
    ----------
    x : dict
        One example with the parallel sequences ``'answers.score'`` and
        ``'answers.text'`` plus a ``'prompt'`` entry.
    max_pairs : int, optional
        Maximum number of pairs kept for the example. When more pairs exist,
        ``max_pairs`` of them are drawn at random (default 10).

    Returns
    -------
    dict
        Equal-length lists under ``'prompt'``, ``'scores'``, ``'answers'``
        and ``'weight'``. ``scores[i]`` and ``answers[i]`` always describe
        the same two answers, and every weight is ``1 / number_of_pairs_kept``.
    """
    scored_answers = list(zip(x['answers.score'], x['answers.text']))
    pairs = list(combinations(scored_answers, 2))

    # Sample whole (score, text) pairs in ONE draw. Sampling the score tuples
    # and the answer tuples separately would pick different pairs for each
    # list and break the score <-> answer correspondence.
    if len(pairs) > max_pairs:
        pairs = random.sample(pairs, max_pairs)

    num_pairs = len(pairs)

    # Comprehension (not ``[1. / num_pairs] * num_pairs``) so that an example
    # with fewer than two answers yields empty lists instead of dividing by 0.
    return {'prompt': [x['prompt'] for _ in range(num_pairs)],
            'scores': [(first[0], second[0]) for first, second in pairs],
            'answers': [(first[1], second[1]) for first, second in pairs],
            'weight': [1. / num_pairs for _ in range(num_pairs)]}

In [None]:
# Apply pair_answers to every example of ds_RM; `remove_columns` names the
# input columns to discard from the mapped dataset.
ds_RM_paired = ds_RM.map(pair_answers,remove_columns=['answers.score', 'answers.text',
                                                      'title_body','prompt'])

In [None]:
#ds_RM_fixed=datasets.DatasetDict()
# After the pair_answers map each row holds lists (one entry per answer pair).
# For every split, turn each row into its own Dataset and concatenate them,
# overwriting the split in place.
# NOTE(review): building one Dataset per row looks slow for large splits, and
# re-running this cell on already-flattened data is presumably not safe — confirm.
for split in ['train','validation','test']:
    print(f'working on split {split}')
    length = len(ds_RM_paired[split])
    ds_RM_paired[split] = datasets.concatenate_datasets([Dataset.from_dict(ds_RM_paired[split][i])\
                               for i in range(length)])

working on split train
working on split validation
working on split test


In [None]:
def reformat_paired_dataset(x):
    """Label the two answers of a pair as chosen / rejected by their scores.

    ``x['scores']`` and ``x['answers']`` each hold two entries. The answer
    with the strictly higher score becomes ``'chosen'``; on a tie the second
    answer is the one reported as chosen.
    """
    first_score, second_score = x['scores'][0], x['scores'][1]

    # Index of the preferred answer, then of the other one.
    winner, loser = (0, 1) if first_score > second_score else (1, 0)

    return {'chosen': x['answers'][winner],
            'rejected': x['answers'][loser],
            'score_accepted': x['scores'][winner],
            'score_rejected': x['scores'][loser]}

def add_length_index(ds,tokenizer):
    """Return ``ds`` with an added ``'lengths'`` column.

    For each example the value is the token count of ``'prompt'`` plus the
    token count of the longer of ``'chosen'`` and ``'rejected'``, where a
    token count is ``len(tokenizer(text)['input_ids'])``.
    """

    def n_tokens(text):
        # Number of input ids the tokenizer produces for one string.
        return len(tokenizer(text)['input_ids'])

    def tot_length(example):
        answer_tokens = max(n_tokens(example['chosen']),
                            n_tokens(example['rejected']))
        return answer_tokens + n_tokens(example['prompt'])

    return ds.map(lambda row: {'lengths': tot_length(row)})

In [None]:
from getpass import getpass
# Prompt for the Hugging Face token instead of hard-coding it in the notebook.
hf_token = getpass('hf_token: ')

# Load the Llama-2-7b tokenizer, authenticating with the token entered above.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf',
                                          token = hf_token)


In [None]:
# Inspect the 'scores' entry of one training row.
ds_RM_paired['train'][5]['scores']

[[178, 73], [178, 57], [73, 57]]

In [None]:
# 1) derive chosen / rejected answers and their scores for every pair,
# 2) drop the raw 'scores' / 'answers' columns,
# 3) add the 'lengths' column computed with the tokenizer.
ds_RM_paired_V2 = ds_RM_paired.map(lambda x:reformat_paired_dataset(x))
ds_RM_paired_V2 = ds_RM_paired_V2.remove_columns(['scores','answers'])
ds_RM_paired_V2 = add_length_index(ds_RM_paired_V2,tokenizer)

Map:   0%|          | 0/112659 [00:00<?, ? examples/s]

Map:   0%|          | 0/7204 [00:00<?, ? examples/s]

Map:   0%|          | 0/10393 [00:00<?, ? examples/s]

Map:   0%|          | 0/112659 [00:00<?, ? examples/s]

Map:   0%|          | 0/7204 [00:00<?, ? examples/s]

Map:   0%|          | 0/10393 [00:00<?, ? examples/s]

In [None]:
# Persist the paired dataset; the "Multiple Pairs of Answers" section reloads
# it from this same path.
ds_RM_paired_V2.save_to_disk('./data/ds_RM_paired')

Saving the dataset (0/1 shards):   0%|          | 0/112659 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7204 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10393 [00:00<?, ? examples/s]

# Training


## Single Pair of Answers

In [None]:
from getpass import getpass
# Read both secrets interactively so they are never stored in the notebook.
wandb_token = getpass('input wandb token ')
hf_token = getpass('input hf token ')

input wandb token ··········
input hf token ··········


In [None]:
# Maps a full Hub model id to the short name used when building run and
# output-directory names.
model_name_simplifier = {}

model_name_simplifier['dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'] ='llama-7b-SFT-qlora-eli5-wiki'
model_name_simplifier['dhmeltzer/llama-7b-SFT_ds_eli5_1024_r_64_alpha_16_merged']='llama-7b-SFT-qlora-eli5'
model_name_simplifier['dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16_merged']='llama-7b-SFT-qlora-wiki'

In [None]:
# Model to fine-tune; must be one of the keys of model_name_simplifier.
model_id = 'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16_merged'

In [None]:
model_id = model_id
model_name = model_name_simplifier[model_id]+'_DPO'

dataset_path = './data/ds_RM_top_2_1024'
epochs = 1
optim = 'paged_adamw_8bit'
per_device_train_batch_size=32
per_device_eval_batch_size = 32
gradient_accumulation_steps=4

now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

ds_name = dataset_path.split('/')[-1]
output_dir = f'./{model_name}_{ds_name}/models'
logging_dir = f'{output_dir}/logs'

run_name = f'{model_name}_{ds_name}_{time_stamp}'
optim = 'paged_adamw_8bit'

from pathlib import Path
Path(output_dir).mkdir(parents=True, exist_ok=True)
Path(logging_dir).mkdir(parents=True, exist_ok=True)

repo_id = f'{model_name}_{ds_name}'


!python ./run_dpo.py \
--output_dir {output_dir} \
--logging_dir {logging_dir} \
--model_id {model_id} \
--dataset_path {dataset_path} \
--run_name {run_name} \
--repo_id {repo_id} \
--report_to_wandb 1 \
--epochs 1 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--gradient_accumulation_steps 4 \
--optim {optim} \
--hf_token {hf_token} \
--wandb_token {wandb_token} \
--use_flash_attention 1 \
--logging_steps 10 \
--resume_from_checkpoint 0

In [None]:
from google.colab import runtime
# Release the Colab runtime once the cells above have finished.
runtime.unassign()

## Multiple Pairs of Answers

In [None]:
# Reload the paired dataset written earlier with save_to_disk.
ds_RM_paired_V2 = datasets.load_from_disk('./data/ds_RM_paired')

In [None]:
# Shuffle with a fixed seed so the row order is reproducible.
ds_RM_paired_V2 = ds_RM_paired_V2.shuffle(seed=12321)

In [None]:
def DPO_training_margin(model_id,
                train_dataset,
                eval_dataset,
                alpha=1,
                include_margin = True,
                use_peft = True,
                hf_token = None,
                wandb_token = None,
                gradient_checkpointing=True,
                r=64,
                lora_alpha=16,
                lora_dropout=0.1,
                beta=.1,
                bias='none',
                task_type='CAUSAL_LM',
                max_prompt_length=4096,
                max_length=4096,
                epochs = 1,
                max_steps = -1,
                lr=5e-4,
                weight_decay=.1,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=32,
                gradient_accumulation_steps=8,
                optim='adamw_torch_fused',
                warmup_ratio=0.03,
                lr_scheduler_type='cosine',
                auto_find_batch_size = True,
                group_by_length=True,
                dataloader_num_workers=2,
                logging_steps=10,
                save_total_limit=3,
                save_strategy='steps',
                save_steps =.1,
                eval_steps=.1,
                load_best_model_at_end=True,
                project_name='DPO_margin_training_dm',
                entity='ft-llmmm',
                torch_compile=False,
                length_column_name='lengths',
                truncation_mode='keep_start',
                repo_id = None,
                output_dir = None,
                hub_strategy = 'every_save',
                ds_name = None):
    """Fine-tune ``model_id`` with DPO, optionally using the margin-aware trainer.

    Parameters
    ----------
    model_id : str
        Model (and tokenizer) identifier passed to ``from_pretrained``.
    train_dataset, eval_dataset
        Datasets handed to the trainer unchanged.
    alpha
        Forwarded to ``DPOTrainer_with_margins``; unused when
        ``include_margin`` is false.
    include_margin : bool
        Use ``DPOTrainer_with_margins`` when true, plain ``DPOTrainer`` otherwise.
    use_peft : bool
        When true the model is loaded 4-bit (NF4) quantized and wrapped by
        ``create_peft_model`` using ``r``, ``lora_alpha``, ``lora_dropout``,
        ``bias`` and ``task_type``. When false it is loaded without
        quantization and trained as is.
    hf_token
        Passed to ``TrainingArguments`` as ``hub_token``.
    wandb_token
        When truthy, logs in to Weights & Biases, starts a run in
        ``project_name`` / ``entity`` and enables wandb reporting.
    beta, max_prompt_length, max_length, truncation_mode
        Forwarded to the trainer.
    repo_id
        When truthy, checkpoints are pushed to this Hub repository and, after
        training, a final evaluation, model card and push are performed.
    output_dir
        Where checkpoints, logs and the final model go. When ``None`` a name
        is derived from the model id, ``ds_name`` and (with PEFT) the LoRA
        settings.
    ds_name
        Dataset label used only for the default ``output_dir``. When ``None``
        a notebook-level ``ds_name`` variable is used if one exists, else
        ``'dataset'``.

    All remaining keyword arguments are forwarded to ``TrainingArguments``
    under the names visible in the body.

    Returns
    -------
    None
        The trained model is saved to ``output_dir``.
    """
    # Precision: bf16 on GPUs with compute capability >= 8, fp16 on older
    # CUDA GPUs, and neither when no GPU is available.
    if torch.cuda.is_available():
        bf16 = torch.cuda.get_device_capability()[0] >= 8
        fp16 = not bf16
    else:
        bf16 = False
        fp16 = False

    if use_peft:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        bnb_config = None

    model = AutoModelForCausalLM.from_pretrained(
            model_id,
            # this is needed for gradient checkpointing
            use_cache=not gradient_checkpointing,
            device_map="auto",
            quantization_config=bnb_config
        )

    model.train()

    if use_peft:
        model = create_peft_model(model,
                                r=r,
                                lora_alpha=lora_alpha,
                                lora_dropout=lora_dropout,
                                bias=bias,
                                task_type=task_type,
                                gradient_checkpointing=gradient_checkpointing,
                                bf16=bf16)

    tokenizer = AutoTokenizer.from_pretrained(
            model_id,
        )

    # Reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model_name = model_id.split('/')[-1]

    if output_dir is None:
        if ds_name is None:
            # Earlier versions read a notebook-level `ds_name` here; keep
            # honouring it when present instead of raising NameError.
            ds_name = globals().get('ds_name', 'dataset')

        output_dir = f'./{model_name}_DPO_{ds_name}'

        if use_peft:
            output_dir += f'_r_{r}_alpha_{lora_alpha}'

    if wandb_token:
        wandb.login(key=wandb_token)

        wandb.init(
            job_type='training',
            project=project_name,
            entity=entity,
            name = repo_id
            )

    training_args = TrainingArguments(
        logging_dir =output_dir+'/logs',
        output_dir= output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        bf16=bf16,  # Use BF16 if available
        fp16=fp16,
        learning_rate=lr,
        num_train_epochs=epochs,
        max_steps = max_steps,
        gradient_checkpointing=gradient_checkpointing,
        optim=optim,
        warmup_ratio=warmup_ratio,
        weight_decay = weight_decay,
        gradient_accumulation_steps=gradient_accumulation_steps,
        group_by_length=group_by_length,
        # logging strategies
        logging_strategy="steps",
        logging_steps=logging_steps,
        save_strategy=save_strategy,
        # evaluation follows the same schedule type as checkpointing
        evaluation_strategy = save_strategy,
        save_steps = save_steps,
        eval_steps = eval_steps,
        lr_scheduler_type=lr_scheduler_type,
        hub_token=hf_token,
        # 'none' (the string) switches reporting off explicitly.
        report_to='wandb' if wandb_token else 'none',
        dataloader_num_workers = dataloader_num_workers,
        load_best_model_at_end=load_best_model_at_end,
        save_total_limit = save_total_limit,
        remove_unused_columns=False,
        disable_tqdm=False,
        torch_compile=torch_compile,
        length_column_name=length_column_name,
        auto_find_batch_size=auto_find_batch_size,
        push_to_hub = True if repo_id else False,
        hub_strategy=hub_strategy,
        hub_model_id=repo_id
    )

    # Arguments shared by both trainer flavours.
    trainer_kwargs = dict(
        args=training_args,
        beta=beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=max_prompt_length,
        max_length=max_length,
        truncation_mode=truncation_mode
    )

    if include_margin:
        dpo_trainer = DPOTrainer_with_margins(model, alpha=alpha, **trainer_kwargs)
    else:
        dpo_trainer = DPOTrainer(model, **trainer_kwargs)

    dpo_trainer.train()

    if repo_id:
        # Final evaluation, then write the model card and push to the Hub.
        dpo_trainer.evaluate()
        dpo_trainer.create_model_card(model_name=repo_id)
        dpo_trainer.push_to_hub()

    dpo_trainer.save_model(output_dir)

In [None]:
import gc
# Run the garbage collector and release PyTorch's cached CUDA memory before
# the next training run.
gc.enable()
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Keep only examples whose 'lengths' value (prompt + longer answer, in
# tokens) is at most 128.
ds_RM_paired_V3 = ds_RM_paired_V2.filter(
    lambda x: x['lengths']<=128
)

In [None]:
# Small trial run of DPO_training_margin on a random subsample.
# NOTE(review): `device` is not used anywhere else in this cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#model_id = 'EleutherAI/pythia-70m-deduped'
model_id = 'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'

# 1000 random rows out of the first 10000 training rows; this assumes the
# filtered split has at least 10000 rows, and no random seed is set here.
train_dataset = ds_RM_paired_V3['train'].select(
    random.sample(range(10000),1000))

# 100 random rows out of the first 1000 validation rows (same assumption).
eval_dataset = ds_RM_paired_V3['validation'].select(
    random.sample(range(1000),100))

train_dataset.set_format(None)
eval_dataset.set_format(None)

# include_margin=False selects the plain DPOTrainer; no tokens and no repo_id
# are passed, so nothing is pushed to the Hub.
DPO_training_margin(model_id,
                train_dataset,
                eval_dataset,
                include_margin = False,
                alpha = 1,
                hf_token = None,
                wandb_token = None,
                epochs=1,
                max_steps = -1,
                auto_find_batch_size = False,
                project_name='DPO_margin_training_dm',
                repo_id = None,
                output_dir = './test_margins',
                hub_strategy = 'every_save',
                use_peft = True,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                gradient_accumulation_steps=64,
                optim='paged_adamw_8bit',
                logging_steps=1,
                save_steps=.1,
                eval_steps=.1
                )

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Found 7 modules to quantize: ['o_proj', 'up_proj', 'gate_proj', 'q_proj', 'down_proj', 'k_proj', 'v_proj']
trainable params: 159,907,840 || all params: 6,898,323,456 || trainable%: 2.3180681656919973


RuntimeError: ignored

# Scratch

## older training

In [None]:
def DPO_training(model_id,
                ds_name,
                hf_token = None,
                wandb_token = None,
                gradient_checkpointing=True,
                r=64,
                lora_alpha=16,
                lora_dropout=0.1,
                beta=.1,
                bias='none',
                task_type='CAUSAL_LM',
                max_prompt_length=4096,
                max_length=4096,
                epochs = 1,
                max_steps = -1,
                lr=5e-4,
                weight_decay=.1,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=32,
                gradient_accumulation_steps=8,
                optim='adamw_torch_fused',
                warmup_ratio=0.03,
                lr_scheduler_type='cosine',
                auto_find_batch_size = True,
                group_by_length=True,
                dataloader_num_workers=2,
                logging_steps=10,
                save_total_limit=3,
                save_strategy='steps',
                save_steps =.1,
                eval_steps=.1,
                load_best_model_at_end=True,
                project_name='DPO_training_dm',
                entity='ft-llmmm',
                torch_compile=False,
                length_column_name='lengths',
                truncation_mode='keep_start',
                repo_id = None,
                output_dir = None,
                hub_strategy = 'every_save'):
    """Fine-tune ``model_id`` with DPO on a 4-bit quantized PEFT model.

    Parameters
    ----------
    model_id : str
        Model (and tokenizer) identifier passed to ``from_pretrained``.
    ds_name
        Either a string key into the notebook-level ``ds_dict``, or the
        dataset object itself (anything indexable by ``'train'`` and
        ``'validation'``), which is how the cells in this notebook call it.
    hf_token
        Passed to ``TrainingArguments`` as ``hub_token``.
    wandb_token
        When truthy, logs in to Weights & Biases, starts a run in
        ``project_name`` / ``entity``, enables wandb reporting and logs the
        pre-training evaluation as a table.
    r, lora_alpha, lora_dropout, bias, task_type
        Forwarded to ``create_peft_model``.
    beta, max_prompt_length, max_length, truncation_mode
        Forwarded to ``DPOTrainer``.
    repo_id
        When truthy, checkpoints are pushed to this Hub repository and, after
        training, a final evaluation, model card and push are performed.
    output_dir
        Where checkpoints, logs and the final model go; derived from the
        model id, a dataset label and the LoRA settings when ``None``.

    All remaining keyword arguments are forwarded to ``TrainingArguments``
    under the names visible in the body.

    Returns
    -------
    None
        The trained model is saved to ``output_dir``.
    """
    # Precision: bf16 on GPUs with compute capability >= 8, fp16 on older
    # CUDA GPUs, and neither when no GPU is available.
    if torch.cuda.is_available():
        bf16 = torch.cuda.get_device_capability()[0] >= 8
        fp16 = not bf16
    else:
        bf16 = False
        fp16 = False

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
            model_id,
            # this is needed for gradient checkpointing
            use_cache=not gradient_checkpointing,
            device_map="auto",
            quantization_config=bnb_config
        )

    model.train()

    model = create_peft_model(model,
                            r=r,
                            lora_alpha=lora_alpha,
                            lora_dropout=lora_dropout,
                            bias=bias,
                            task_type=task_type,
                            gradient_checkpointing=gradient_checkpointing,
                            bf16=bf16)

    tokenizer = AutoTokenizer.from_pretrained(
            model_id,
        )

    # Reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model_name = model_id.split('/')[-1]

    # Resolve the dataset: a string is looked up in ds_dict, anything else is
    # taken to be the dataset itself. A short label is kept separately so the
    # default output directory never embeds a dataset repr.
    if isinstance(ds_name, str):
        dataset = ds_dict[ds_name]
        ds_label = ds_name
    else:
        dataset = ds_name
        ds_label = repo_id.split('/')[-1] if repo_id else 'custom'

    if output_dir is None:
        output_dir = f'./{model_name}_DPO_{ds_label}_r_{r}_alpha_{lora_alpha}'

    if wandb_token:
        wandb.login(key=wandb_token)

        wandb.init(
            job_type='training',
            project=project_name,
            entity=entity,
            name = repo_id
            )

    training_args = TrainingArguments(
        logging_dir =output_dir+'/logs',
        output_dir= output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        bf16=bf16,  # Use BF16 if available
        fp16=fp16,
        learning_rate=lr,
        num_train_epochs=epochs,
        max_steps = max_steps,
        gradient_checkpointing=gradient_checkpointing,
        optim=optim,
        warmup_ratio=warmup_ratio,
        weight_decay = weight_decay,
        gradient_accumulation_steps=gradient_accumulation_steps,
        group_by_length=group_by_length,
        # logging strategies
        logging_strategy="steps",
        logging_steps=logging_steps,
        save_strategy=save_strategy,
        # evaluation follows the same schedule type as checkpointing
        evaluation_strategy = save_strategy,
        save_steps = save_steps,
        eval_steps = eval_steps,
        lr_scheduler_type=lr_scheduler_type,
        hub_token=hf_token,
        # 'none' (the string) switches reporting off explicitly.
        report_to='wandb' if wandb_token else 'none',
        dataloader_num_workers = dataloader_num_workers,
        load_best_model_at_end=load_best_model_at_end,
        save_total_limit = save_total_limit,
        remove_unused_columns=False,
        disable_tqdm=False,
        torch_compile=torch_compile,
        length_column_name=length_column_name,
        auto_find_batch_size=auto_find_batch_size,
        push_to_hub = True if repo_id else False,
        hub_strategy=hub_strategy,
        hub_model_id=repo_id
    )

    train_dataset = dataset['train']
    eval_dataset = dataset['validation']

    dpo_trainer = DPOTrainer(
        model,
        args=training_args,
        beta=beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=max_prompt_length,
        max_length=max_length,
        truncation_mode=truncation_mode
    )

    # Baseline metrics before any training step.
    original_performance = dpo_trainer.evaluate()
    if wandb_token:
        # Only log the table when a wandb run was actually started above.
        wandb.log({'initial-performance': wandb.Table(dataframe=pd.DataFrame(original_performance, index=["Performance"]))})

    dpo_trainer.train()

    if repo_id:
        # Final evaluation, then write the model card and push to the Hub.
        dpo_trainer.evaluate()
        dpo_trainer.create_model_card(model_name=repo_id)
        dpo_trainer.push_to_hub()

    dpo_trainer.save_model(output_dir)

In [None]:
# DPO run on ds_RM_top_2_filt[1024].
# NOTE(review): `dataset` is passed in the position of DPO_training's
# `ds_name` parameter — confirm the function handles a dataset object there.
model_id = 'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'
dataset = ds_RM_top_2_filt[1024]
epochs = 1
optim = 'paged_adamw_8bit'
per_device_train_batch_size=32
per_device_eval_batch_size = 32
gradient_accumulation_steps=4

DPO_training(model_id,
            dataset,
            hf_token = hf_token,
            wandb_token = wandb_token,
            epochs = epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim=optim,
            auto_find_batch_size=False,
            repo_id = 'dhmeltzer/llama-7b-SFT-eli5wiki1024-DPO_top2-1024-r64-alpha16')

Downloading (…)lve/main/config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

Found 7 modules to quantize: ['k_proj', 'gate_proj', 'o_proj', 'up_proj', 'q_proj', 'v_proj', 'down_proj']
trainable params: 159,907,840 || all params: 6,898,323,456 || trainable%: 2.3180681656919973


Downloading (…)okenizer_config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
19,0.6887,0.682096,-0.375533,-0.47066,0.539983,0.095127,-202.567566,-206.854294,0.130077,0.156852
38,0.6852,0.683311,-0.408925,-0.54935,0.552083,0.140424,-203.354477,-207.188248,-0.001801,0.031421
57,0.6899,0.680961,-0.059625,-0.166149,0.566604,0.106524,-199.522461,-203.695221,0.033546,0.075674
76,0.6638,0.67197,-0.325149,-0.486091,0.575126,0.160942,-202.721878,-206.350449,0.030284,0.073753
95,0.6768,0.668999,0.063126,-0.058047,0.590173,0.121173,-198.441437,-202.467697,0.033054,0.084962
114,0.676,0.669399,-0.133009,-0.261908,0.594276,0.128899,-200.480042,-204.429077,0.054214,0.106292
133,0.6703,0.666981,-0.168998,-0.291161,0.593224,0.122163,-200.772583,-204.788971,0.056855,0.109225
152,0.6812,0.664921,-0.118525,-0.249208,0.595013,0.130684,-200.353043,-204.28421,0.047949,0.101074
171,0.6808,0.664863,-0.112347,-0.244159,0.588594,0.131813,-200.302597,-204.222443,0.045132,0.098196




Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/4.22k [00:00<?, ?B/s]

adapter_model.bin:   0%|          | 0.00/320M [00:00<?, ?B/s]

In [None]:
# Collect garbage and release cached CUDA memory between training runs.
gc.enable()
gc.collect()
torch.cuda.empty_cache()

In [None]:
# DPO run on ds_RM_contrast_filt[1024]; same settings as the top-2 run, with
# only the dataset and repo_id changed.
# NOTE(review): `dataset` is passed in the position of DPO_training's
# `ds_name` parameter — confirm the function handles a dataset object there.
model_id = 'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'
dataset = ds_RM_contrast_filt[1024]
epochs = 1
optim = 'paged_adamw_8bit'
per_device_train_batch_size=32
per_device_eval_batch_size = 32
gradient_accumulation_steps=4

DPO_training(model_id,
            dataset,
            hf_token = hf_token,
            wandb_token = wandb_token,
            epochs = epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim=optim,
            auto_find_batch_size=False,
            repo_id = 'dhmeltzer/llama-7b-SFT-eli5wiki1024-DPO_contrast-1024-r64-alpha16')

# Free memory once the run has finished.
gc.collect()
torch.cuda.empty_cache()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Found 7 modules to quantize: ['k_proj', 'gate_proj', 'o_proj', 'up_proj', 'q_proj', 'v_proj', 'down_proj']
trainable params: 159,907,840 || all params: 6,898,323,456 || trainable%: 2.3180681656919973




0,1
eval/logits/chosen,██▁▃▃▄▅▅▅▅▅
eval/logits/rejected,▇█▁▃▃▃▄▄▄▃▃
eval/logps/chosen,▇▁▁▆▂█▅▅▅▅▅
eval/logps/rejected,█▂▁▆▂▇▅▄▅▅▅
eval/loss,█▅▆▅▃▂▂▂▁▁▁
eval/rewards/accuracies,▁▇▇████████
eval/rewards/chosen,▇▁▁▆▂█▅▅▅▅▅
eval/rewards/margins,▁▅▇▆█▆▇▆▇▇▇
eval/rewards/rejected,█▂▁▆▂▇▅▄▅▅▅
eval/runtime,█▁▂▁▂▁▂▂▂▁▂

0,1
eval/logits/chosen,0.0982
eval/logits/rejected,0.04513
eval/logps/chosen,-204.22244
eval/logps/rejected,-200.3026
eval/loss,0.66486
eval/rewards/accuracies,0.58859
eval/rewards/chosen,-0.11235
eval/rewards/margins,0.13181
eval/rewards/rejected,-0.24416
eval/runtime,153.5012


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669247816662392, max=1.0…

KeyboardInterrupt: ignored

In [None]:
# DPO run on ds_RM_random_filt[1024]; same settings as the runs above, with
# only the dataset and repo_id changed.
# NOTE(review): `dataset` is passed in the position of DPO_training's
# `ds_name` parameter — confirm the function handles a dataset object there.
model_id = 'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'
dataset = ds_RM_random_filt[1024]
epochs = 1
optim = 'paged_adamw_8bit'
per_device_train_batch_size=32
per_device_eval_batch_size = 32
gradient_accumulation_steps=4

DPO_training(model_id,
            dataset,
            hf_token = hf_token,
            wandb_token = wandb_token,
            epochs = epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim=optim,
            auto_find_batch_size=False,
            repo_id = 'dhmeltzer/llama-7b-SFT-eli5wiki1024-DPO_random-1024-r64-alpha16')

# Free memory once the run has finished.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Scratch setup: notebook-level hyper-parameters followed by a 4-bit model
# load, PEFT wrapping and tokenizer load.
gradient_checkpointing=True
r=64
lora_alpha=16
lora_dropout=0.1
bias='none'
task_type='CAUSAL_LM'
max_seq_length=512
epochs = 1
max_steps = -1
lr=2e-4
weight_decay=.01
per_device_train_batch_size=1
per_device_eval_batch_size=1
gradient_accumulation_steps=1
optim='paged_adamw_32bit'
warmup_ratio=0.03
group_by_length=True
dataloader_num_workers=2
logging_steps=10
save_total_limit=3
save_strategy='steps'
save_steps =.2
eval_steps=.2
load_best_model_at_end=True
project_name='DPO_training_dm'
entity='ft-llmmm'
torch_compile=False
length_column_name='lengths'

# Weights come from the SFT checkpoint; the tokenizer comes from the base model.
SFT_model_id = 'dhmeltzer/Llama-2-7b-hf-wiki-no-gl-r-64-alpha-16-full'
base_model_id = 'meta-llama/Llama-2-7b-hf'

#SFT_model_id = 'distilgpt2'
#base_model_id = SFT_model_id

# 4-bit NF4 quantization with double quantization and bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
        SFT_model_id,
        use_cache=False
        if gradient_checkpointing
        else True,  # this is needed for gradient checkpointing
        device_map="auto",
        quantization_config=bnb_config,
        #use_auth_token=hf_token
    )

# NOTE(review): `bf16` is not assigned in this cell; it must already exist at
# notebook level or this call raises NameError — confirm where it is defined.
model = create_peft_model(model,
                          r=r,
                          lora_alpha=lora_alpha,
                          lora_dropout=lora_dropout,
                          bias=bias,
                          task_type=task_type,
                          gradient_checkpointing=gradient_checkpointing,
                          bf16=bf16)

model.train()

tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        #use_auth_token=hf_token
    )

# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token



In [None]:
# Output location and train / validation splits for the scratch run.
# (The f-string prefix has no placeholders; the value is a plain literal.)
output_dir = f'./SFT_wiki_no_gl_DPO/models'



train_dataset = ds_RM_top_2['train']
eval_dataset = ds_RM_top_2['validation']

In [None]:
# Drop these notebook-level names; raises NameError if either one is not
# currently defined.
del DataCollator
del dpo_trainer

In [None]:
# Collect garbage and release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()