In [None]:
# Install the training stack into the kernel's environment (`%pip`, not `!pip`).
# `bitsandbytes` provides the 8-bit optimizers / quantization used below.
# TODO: pin versions known to work together for reproducible runs.
%pip install -q huggingface_hub transformers datasets accelerate evaluate bitsandbytes peft deepspeed

In [None]:
# Policy (SFT / PPO / DPO) backbone and the smaller backbone used for the reward model.
MODEL_HF_PATH = 'facebook/opt-iml-max-1.3b'
REWARD_MODEL_HF_PATH = 'facebook/opt-125m'

BASE_DIR = '.'
DATA_DIR = f'{BASE_DIR}/datasets'
MODEL_DIR = f'{BASE_DIR}/models'

MANUAL_INSTRUCTION_DATASET = f'{DATA_DIR}/manual_instructions.csv'
# Default word budget applied by `split_truncate` to questions and answers.
NUM_WORDS = 256


def split_truncate(sentence: str, num_words=None) -> str:
    """Collapse whitespace in `sentence` and keep at most `num_words` words.

    Args:
        sentence: raw text.
        num_words: maximum number of whitespace-separated words to keep.
            Defaults to the module-level `NUM_WORDS`, read at call time so a
            later change of `NUM_WORDS` is honoured.

    Returns:
        The first `num_words` words joined by single spaces.
    """
    if num_words is None:
        num_words = NUM_WORDS
    return " ".join(sentence.split()[:num_words])

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    PreTrainedModel
)
from transformers.trainer_pt_utils import nested_detach
from datasets import load_dataset, concatenate_datasets
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
import torch
import torch.nn as nn
import torch.nn.functional as fn
import torch.utils.data as data_utils
import numpy as np
import random

from typing import List, Dict, Any, Optional, Union, Literal
from dataclasses import dataclass
from collections import deque, defaultdict

# Supervised Fine-tuning

In [None]:
# Hub id of the instruction data used for supervised fine-tuning, and the
# local directory the fine-tuned model is written to.
SFT_DATASET_PATH = 'ArmelR/stack-exchange-instruction'
SFT_MODEL_PATH = '/'.join([MODEL_DIR, 'sft_opt'])

In [None]:
# True: load the base model in 8 bit and train LoRA adapters on top of it.
# False: load the plain model without adapters.
USE_8_BIT_TRAINING = False

# The tokenizer is the same in both modes, so it is loaded once.
sft_tokenizer = AutoTokenizer.from_pretrained(MODEL_HF_PATH)

if USE_8_BIT_TRAINING:
    sft_opt = AutoModelForCausalLM.from_pretrained(
        MODEL_HF_PATH,
        load_in_8bit=True,
        device_map='auto',
    )
    sft_opt = prepare_model_for_kbit_training(sft_opt)
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    sft_opt = get_peft_model(sft_opt, lora_config)
    sft_opt.print_trainable_parameters()
else:
    sft_opt = AutoModelForCausalLM.from_pretrained(MODEL_HF_PATH, device_map='auto')

In [None]:
# Fixed seed so the 100k sub-sample and the train/test split are reproducible.
SFT_SPLIT_SEED = 42

# Stack Exchange instruction data, sub-sampled to 100k rows
# (columns after cleanup: question, response).
sft_dataset = load_dataset(SFT_DATASET_PATH, split='train')
sft_dataset = sft_dataset.remove_columns(['qid', 'date', 'metadata'])
sft_dataset = sft_dataset.train_test_split(
    train_size=100000, seed=SFT_SPLIT_SEED
)['train']

# Hand-written preference data: only the preferred answer is used for SFT.
instruction_dataset = load_dataset(
    'csv',
    data_files=MANUAL_INSTRUCTION_DATASET,
    split='train'
)
instruction_dataset = instruction_dataset.remove_columns(['rejected'])
instruction_dataset = instruction_dataset.rename_column('chosen', 'response')

final_sft_dataset = concatenate_datasets([sft_dataset, instruction_dataset])
final_sft_dataset = final_sft_dataset.map(
    lambda x: {
        'inputs': (
            f'<Question>: {split_truncate(x["question"])}'
            f'<Answer>: {split_truncate(x["response"])}'
        )
    }
)
# Terminate every example with EOS so the model learns where an answer ends.
final_sft_dataset = final_sft_dataset.map(
    lambda x: {
        'input_ids': (
            sft_tokenizer(x['inputs'])['input_ids'] + [sft_tokenizer.eos_token_id]
        )
    },
    remove_columns=['inputs'],
    num_proc=4
)
# Read the limit once so the filter closure references an int, not the model
# (avoids serialising / hashing the whole model for the worker processes).
sft_max_positions = sft_opt.config.max_position_embeddings
final_sft_dataset = final_sft_dataset.filter(
    lambda x: len(x['input_ids']) <= sft_max_positions,
    num_proc=4
)
final_sft_dataset = final_sft_dataset.train_test_split(
    test_size=0.1, seed=SFT_SPLIT_SEED
)

In [None]:
# `final_sft_dataset` is a DatasetDict with `train` and `test` splits.
# DatasetDict has no `to_parquet`, and reloading with `split='train'` would drop
# the test split the trainer needs, so each split gets its own file and both
# are reloaded together (the result is again a DatasetDict keyed by split).
sft_parquet_files = {
    split_name: f'{DATA_DIR}/sft_dataset_{split_name}.parquet'
    for split_name in final_sft_dataset.keys()
}
for split_name, parquet_path in sft_parquet_files.items():
    final_sft_dataset[split_name].to_parquet(parquet_path)

final_sft_dataset = load_dataset('parquet', data_files=sft_parquet_files)

In [None]:
# Causal-LM fine-tuning (mlm=False: labels are the input ids, padding ignored)
# with DeepSpeed ZeRO-2 and the optimizer state offloaded to CPU.
sft_trainer = Trainer(
    model=sft_opt,
    args=TrainingArguments(
        output_dir="models",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        fp16=True,
        optim='paged_adamw_8bit',
        logging_steps=1000,
        save_strategy='no',
        evaluation_strategy='steps',
        eval_steps=5000,
        num_train_epochs=2,
        learning_rate=1e-5,
        deepspeed={
            "zero_optimization": {
                "stage": 2,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True
                },
                "allgather_partitions": True,
                "allgather_bucket_size": 5e8,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "overlap_comm": True,
                "contiguous_gradients": True
            }
        },
    ),
    train_dataset = final_sft_dataset['train'],
    eval_dataset = final_sft_dataset['test'],
    data_collator = DataCollatorForLanguageModeling(sft_tokenizer, mlm=False),
    tokenizer = sft_tokenizer,
)
# Without this call the next cell would save the untouched base model.
sft_trainer.train()

In [None]:
# Persist the fine-tuned policy (and its tokenizer) for the RLHF / DPO stages.
sft_trainer.save_model(output_dir=SFT_MODEL_PATH)

# RLHF

## Reward Model Training

In [None]:
# Hub id of the paired (chosen / rejected answer) Stack Exchange data.
REWARD_DATASET_PATH = 'lvwerra/stack-exchange-paired'
# Must differ from SFT_MODEL_PATH: saving the reward model into the SFT
# directory would overwrite the checkpoint that PPO and DPO load later.
REWARD_MODEL_PATH = f'{MODEL_DIR}/reward_opt'

In [None]:
class RewardDataCollatorWithPadding:
    """Pad tokenized chosen / rejected answers into a pairwise reward batch.

    Each feature must provide `input_ids_chosen` and `input_ids_rejected`
    (lists of token ids). Both sides are truncated to at most `max_length`
    tokens and padded to the longest sequence in the batch, capped at
    `max_length`.
    """

    def __init__(self, tokenizer, max_length: int, return_tensors="pt"):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.return_tensors = return_tensors

    def _pad(self, sequences, max_length):
        """Truncate each id list to `max_length` and pad the batch to exactly that width."""
        return self.tokenizer.pad(
            {'input_ids': [ids[:max_length] for ids in sequences]},
            padding='max_length',
            max_length=max_length,
            return_tensors=self.return_tensors,
            return_attention_mask=True
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        longest = max(
            len(ids)
            for feature in features
            for ids in (feature['input_ids_chosen'], feature['input_ids_rejected'])
        )
        max_length = min(longest, self.max_length)

        batch_chosen = self._pad(
            [feature['input_ids_chosen'] for feature in features], max_length
        )
        batch_rejected = self._pad(
            [feature['input_ids_rejected'] for feature in features], max_length
        )
        return {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }


class RewardTrainer(Trainer):
    """Trainer for a pairwise (chosen vs. rejected) scalar reward model.

    Batches come from `RewardDataCollatorWithPadding`. The loss is
    `-log(sigmoid(r_chosen - r_rejected))` averaged over the batch.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Score both answers of every pair and return the pairwise loss.

        `**kwargs` absorbs extra arguments that newer `Trainer` versions pass
        (e.g. `num_items_in_batch`); they are not needed here.

        Returns:
            `loss`, or `(loss, {"rewards_chosen": ..., "rewards_rejected": ...})`
            when `return_outputs` is True.
        """
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"]
        ).logits

        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"]
        ).logits

        loss = -fn.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected
            }
        return loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None
    ):
        """Return `(loss, preferences, labels)` for evaluation.

        `preferences` has one row per pair: column 0 is the softmax weight of
        the chosen answer and column 1 that of the rejected one. `labels` is a
        dummy all-ones vector; `compute_metrics` only reads the predictions.
        """
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)

        # Stack chosen against rejected -> (2, batch, 1), average the last dim,
        # softmax across the pair so the two preferences sum to 1, and
        # transpose to (batch, 2).
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T

        labels = torch.ones(logits.shape[0])

        return loss, logits, labels

In [None]:
# Reward backbone with a single-output head (num_labels=1): one scalar score per sequence.
reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_HF_PATH)
reward_opt = AutoModelForSequenceClassification.from_pretrained(
    REWARD_MODEL_HF_PATH, num_labels=1, device_map='auto'
)

In [None]:
# Fixed seed so the 100k sub-sample is reproducible.
REWARD_SPLIT_SEED = 42

# REWARD_DATASET_PATH is a Hub dataset id, so it is loaded by name (loading it
# as a local parquet file fails); `data_dir` selects its `data/reward` subset.
reward_dataset = load_dataset(
    REWARD_DATASET_PATH,
    data_dir='data/reward',
    split='train'
)
reward_dataset = reward_dataset.remove_columns(['qid', 'date', 'metadata'])
reward_dataset = reward_dataset.rename_columns(
    {'response_j': 'chosen', 'response_k': 'rejected'}
)
reward_dataset = reward_dataset.train_test_split(
    train_size=100000, seed=REWARD_SPLIT_SEED
)['train']

instruction_dataset = load_dataset(
    'csv',
    data_files=MANUAL_INSTRUCTION_DATASET,
    split='train'
)
final_reward_dataset = concatenate_datasets([reward_dataset, instruction_dataset])
final_reward_dataset = final_reward_dataset.map(
    lambda x: {
        'inputs_chosen': (
            f'<Question>: {split_truncate(x["question"])}'
            f'<Answer>: {split_truncate(x["chosen"])}'
        ),
        'inputs_rejected': (
            f'<Question>: {split_truncate(x["question"])}'
            f'<Answer>: {split_truncate(x["rejected"])}'
        )
    }
)
final_reward_dataset = final_reward_dataset.map(
    lambda x: {
        'input_ids_chosen': reward_tokenizer(x['inputs_chosen'])['input_ids'],
        'input_ids_rejected': reward_tokenizer(x['inputs_rejected'])['input_ids']
    },
    remove_columns=['inputs_chosen', 'inputs_rejected'],
    num_proc=4
)
# Read the limit once so the filter closure references an int, not the model.
reward_max_positions = reward_opt.config.max_position_embeddings
final_reward_dataset = final_reward_dataset.filter(
    lambda x: (
        len(x['input_ids_chosen']) <= reward_max_positions
        and len(x['input_ids_rejected']) <= reward_max_positions
    ),
    num_proc=4
)

In [None]:
# Cache the tokenized pairs. The unsplit file is also where the PPO section
# samples its prompts from.
final_reward_dataset.to_parquet(f'{DATA_DIR}/reward_dataset.parquet')
final_reward_dataset = load_dataset(
    'parquet',
    data_files=f'{DATA_DIR}/reward_dataset.parquet',
    split='train'
)
# The reward trainer reads `['train']` and `['test']` (held-out accuracy), but
# the data prepared above was never split, so split it after reloading.
final_reward_dataset = final_reward_dataset.train_test_split(test_size=0.1, seed=42)

In [None]:
def compute_metrics(eval_preds):
    """Share of evaluation pairs in which the chosen answer is preferred.

    `eval_preds` is the `(predictions, labels)` pair handed over by the
    Trainer. Each prediction row holds the chosen-answer preference in column
    0 and the rejected-answer preference in column 1; labels are unused.
    """
    predictions, _labels = eval_preds
    chosen_wins = (predictions[:, 0] > predictions[:, 1]).astype(int)
    return {"accuracy": chosen_wins.mean()}

# DeepSpeed ZeRO-2 with the optimizer state offloaded to CPU.
reward_deepspeed_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "overlap_comm": True,
        "contiguous_gradients": True,
    }
}

reward_training_args = TrainingArguments(
    output_dir="models",
    num_train_epochs=2,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim='adamw_8bit',
    logging_steps=1000,
    evaluation_strategy='steps',
    eval_steps=1000,
    save_strategy='no',
    deepspeed=reward_deepspeed_config,
    # Keep the *_chosen / *_rejected columns: the pairwise collator reads
    # them, but they are not arguments of the model's forward().
    remove_unused_columns=False,
)

reward_collator = RewardDataCollatorWithPadding(
    reward_tokenizer, max_length=reward_opt.config.max_position_embeddings
)

reward_trainer = RewardTrainer(
    model=reward_opt,
    args=reward_training_args,
    train_dataset=final_reward_dataset['train'],
    eval_dataset=final_reward_dataset['test'],
    data_collator=reward_collator,
    tokenizer=reward_tokenizer,
    compute_metrics=compute_metrics,
)
reward_trainer.train()

In [None]:
# Persist the trained reward model; the PPO stage reloads it from this path.
reward_trainer.save_model(output_dir=REWARD_MODEL_PATH)

## Proximal Policy Optimization

In [None]:
# Output directory of the PPO-tuned policy.
PPO_MODEL_PATH = '/'.join([MODEL_DIR, 'ppo_opt'])

In [None]:
# Three models for PPO:
#   - ppo_opt:    the SFT policy being optimised,
#   - ref_opt:    an 8-bit copy of the SFT policy, used to compute the KL penalty,
#   - reward_opt: the reward model trained above.
ppo_tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)
ppo_opt = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH, device_map='auto')
ref_opt = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_PATH, load_in_8bit=True, device_map='auto'
)
reward_opt = AutoModelForSequenceClassification.from_pretrained(
    REWARD_MODEL_PATH, device_map='auto'
)

In [None]:
# Longest prompt (in tokens) kept for PPO; it leaves room for the response
# inside RLHFConfig.max_length.
PPO_MAX_PROMPT_TOKENS = 256

# Sample 10k questions from the cached preference data (fixed seed).
final_ppo_dataset = load_dataset(
    'parquet',
    data_files=f'{DATA_DIR}/reward_dataset.parquet',
    split='train'
).train_test_split(train_size=10000, seed=42)['train']

# RLHFTrainer generates from `inputs['input_ids']`, which the cached file does
# not contain (it only has the tokenized chosen / rejected pairs). Build the
# generation prompt in the same `<Question>: ...<Answer>:` format used for SFT
# and drop every other column.
final_ppo_dataset = final_ppo_dataset.map(
    lambda x: {
        'input_ids': ppo_tokenizer(
            f'<Question>: {split_truncate(x["question"])}<Answer>:'
        )['input_ids']
    },
    remove_columns=final_ppo_dataset.column_names,
    num_proc=4,
)
final_ppo_dataset = final_ppo_dataset.filter(
    lambda x: len(x['input_ids']) <= PPO_MAX_PROMPT_TOKENS,
    num_proc=4,
)

In [None]:
@dataclass
class RLHFConfig:
    """Hyper-parameters of the simplified PPO loop implemented by RLHFTrainer."""

    # Maximum number of past trajectories kept in the replay memory.
    memory_size: int = 16
    # Trajectories drawn (with replacement) from memory per loss computation.
    sample_size: int = 4
    # PPO clipping range, used for both the probability ratio and the values.
    clip_range: float = 0.2
    # Weight of the value-function loss relative to the policy loss.
    value_loss_weight: float = 0.5
    # Weight of the per-token KL penalty against the reference model.
    kl_div_weight: float = 0.2
    # Total (prompt + response) length passed to `generate`; trajectories are
    # padded / truncated to exactly this many tokens so they can be stacked.
    max_length: int = 400


class RLHFTrainer(Trainer):
    """Minimal PPO-style RLHF trainer on top of `transformers.Trainer`.

    Every training step:
      1. samples responses from the policy for the batch prompts,
      2. scores them with the reward model (sequence reward broadcast to every
         token) minus a weighted per-token KL penalty against the reference
         model,
      3. stores the trajectories in a small replay memory, and
      4. returns the clipped PPO policy + value loss on a random memory sample.

    Only generated (non-padding) response tokens contribute to the loss.
    A linear value head is created here; when no optimizer is passed in, it is
    registered on the policy as the `v_head` sub-module so the Trainer's
    optimizer updates it (and `save_model` stores its weights with the policy).
    """

    def __init__(
        self,
        reward_model: PreTrainedModel,
        reference_model: PreTrainedModel,
        rlhf_config: RLHFConfig,
        **regular_trainer_kwargs,
    ):
        assert isinstance(regular_trainer_kwargs.get('model'), PreTrainedModel), 'RLHFTrainer only supports HuggingFace\'s PreTrainedModel'
        super().__init__(**regular_trainer_kwargs)

        assert self.tokenizer is not None, 'RLHFTrainer requires a tokenizer'
        assert not getattr(self.model.config, 'is_encoder_decoder', False), 'RLHFTrainer only supports decoder-only architecture'
        # Decoder-only generation needs left padding so new tokens directly
        # follow each prompt; the default collator shares this tokenizer.
        self.tokenizer.padding_side = 'left'
        self.reward_model = reward_model
        self.reference_model = reference_model
        self.rlhf_config = rlhf_config
        self.value_head = self._create_value_head()
        if self.optimizer is not None:
            self.optimizer.add_param_group({'params': list(self.value_head.parameters())})
        else:
            # The optimizer is created later from the model's parameters.
            self.model.add_module('v_head', self.value_head)

        self.memory = deque(maxlen=self.rlhf_config.memory_size)

    def _create_value_head(self):
        """Dropout + bias-free linear layer mapping the last hidden state to a scalar value."""
        summary_dropout_prob = getattr(self.model.config, 'summary_dropout_prob', 0.1)
        # OPT exposes the size of its final hidden states as `word_embed_proj_dim`.
        hidden_size = (
            getattr(self.model.config, 'word_embed_proj_dim', None)
            or getattr(self.model.config, 'hidden_size', None)
        )
        return nn.Sequential(
            nn.Dropout(summary_dropout_prob),
            nn.Linear(hidden_size, 1, bias=False),
        ).to(self.model.device)

    @staticmethod
    def _masked_mean(tensor, mask):
        """Mean of `tensor` over positions where `mask` is 1 (0 if the mask is empty)."""
        return (tensor * mask).sum() / mask.sum().clamp(min=1.0)

    @classmethod
    def _whiten(cls, tensor, mask, eps=1e-8):
        """Shift / scale `tensor` to zero mean and unit variance over the masked positions."""
        mean = cls._masked_mean(tensor, mask)
        variance = cls._masked_mean((tensor - mean) ** 2, mask)
        return (tensor - mean) * torch.rsqrt(variance + eps)

    def _compute_batch(self, model, sentences, attention_mask, compute_values=True):
        """Per-token log-probabilities (and optionally values) of `sentences`.

        Returns `(log_probs, values)`, both shaped (batch, seq_len - 1);
        entry t belongs to the prediction of token t + 1. `values` is None
        when `compute_values` is False.
        """
        model_outputs = model(
            input_ids=sentences,
            attention_mask=attention_mask,
            output_hidden_states=compute_values,
        )
        logits = model_outputs.logits[:, :-1].float()
        # log_softmax is numerically stable, unlike log(softmax(...)).
        log_probs = torch.gather(
            fn.log_softmax(logits, dim=-1),
            dim=-1,
            index=sentences[:, 1:].unsqueeze(-1),
        ).squeeze(dim=-1)

        values = None
        if compute_values:
            hidden_states = model_outputs.hidden_states[-1].float()
            values = self.value_head(hidden_states)[:, :-1].squeeze(dim=-1)
        return log_probs, values

    def _compute_loss(self, model):
        """Clipped PPO policy loss plus weighted clipped value loss on a memory sample."""
        samples = random.choices(self.memory, k=self.rlhf_config.sample_size)
        (
            sentences,
            attention_mask,
            token_mask,
            old_log_probs,
            old_values,
            advantages,
            returns,
        ) = data_utils.default_collate(samples)

        new_log_probs, new_values = self._compute_batch(model, sentences, attention_mask)
        clip_range = self.rlhf_config.clip_range

        value_losses_1 = (new_values - returns) ** 2
        value_losses_2 = (
            torch.clip(new_values, old_values - clip_range, old_values + clip_range)
            - returns
        ) ** 2
        value_loss = self._masked_mean(torch.max(value_losses_1, value_losses_2), token_mask)

        ratios = torch.exp(new_log_probs - old_log_probs)
        policy_losses_1 = -advantages * ratios
        policy_losses_2 = -advantages * torch.clip(ratios, 1.0 - clip_range, 1.0 + clip_range)
        policy_loss = self._masked_mean(torch.max(policy_losses_1, policy_losses_2), token_mask)

        return policy_loss + self.rlhf_config.value_loss_weight * value_loss

    def _generate_sentences(self, model, prompts, prompt_attention_mask=None):
        """Sample continuations of `prompts` and bring them to exactly `max_length` tokens.

        Returns `(sentences, attention_mask, response_mask)`, each shaped
        (batch, max_length). `response_mask` is 1 only on generated,
        non-padding tokens.
        """
        max_length = self.rlhf_config.max_length
        pad_token_id = self.tokenizer.pad_token_id
        if prompt_attention_mask is None:
            prompt_attention_mask = (prompts != pad_token_id).long()
        prompt_attention_mask = prompt_attention_mask.long()

        sequences = model.generate(
            input_ids=prompts,
            attention_mask=prompt_attention_mask,
            max_length=max_length,
            do_sample=True,
            pad_token_id=pad_token_id,
            return_dict_in_generate=False,
        )

        prompt_width = prompts.shape[1]
        # `generate` fills rows that already finished with pad tokens; those
        # are neither attended to nor trained on.
        generated_mask = (sequences[:, prompt_width:] != pad_token_id).long()
        attention_mask = torch.cat([prompt_attention_mask, generated_mask], dim=1)
        response_mask = torch.cat(
            [torch.zeros_like(prompt_attention_mask), generated_mask], dim=1
        )

        # Keep at most `max_length` columns (only exceeded when a prompt is
        # itself longer than max_length) ...
        sequences = sequences[:, -max_length:]
        attention_mask = attention_mask[:, -max_length:]
        response_mask = response_mask[:, -max_length:]
        # ... and left-pad shorter batches so trajectories from different
        # steps can be stacked by `default_collate`.
        missing = max_length - sequences.shape[1]
        if missing > 0:
            sequences = fn.pad(sequences, (missing, 0), value=pad_token_id)
            attention_mask = fn.pad(attention_mask, (missing, 0), value=0)
            response_mask = fn.pad(response_mask, (missing, 0), value=0)
        return sequences, attention_mask, response_mask

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Roll out the batch prompts, store the trajectories, return the PPO loss.

        Extra keyword arguments passed by newer `Trainer` versions (e.g.
        `num_items_in_batch`) are accepted and ignored.
        """
        sentences, attention_mask, response_mask = self._generate_sentences(
            model, inputs['input_ids'], inputs.get('attention_mask')
        )
        # log_probs[:, t] scores token t + 1, so the response mask shifts by one.
        token_mask = response_mask[:, 1:].float()

        with torch.no_grad():
            log_probs, values = self._compute_batch(model, sentences, attention_mask)

            ref_device = self.reference_model.device
            ref_log_probs, _ = self._compute_batch(
                self.reference_model,
                sentences.to(ref_device),
                attention_mask.to(ref_device),
                compute_values=False,
            )
            ref_log_probs = ref_log_probs.to(log_probs.device)

            reward_device = self.reward_model.device
            rewards = self.reward_model(
                input_ids=sentences.to(reward_device),
                attention_mask=attention_mask.to(reward_device),
            ).logits.float().to(log_probs.device)  # (batch, 1)

            # Sequence reward on every token, minus the weighted KL penalty.
            kl_div = log_probs - ref_log_probs
            rewards = rewards.expand_as(log_probs) - self.rlhf_config.kl_div_weight * kl_div

            # Returns are the un-normalised targets for the value head; only
            # the advantages are whitened.
            returns = rewards
            advantages = self._whiten(rewards - values, token_mask)

        for trajectory in zip(
            sentences, attention_mask, token_mask, log_probs, values, advantages, returns
        ):
            self.memory.append(trajectory)

        loss = self._compute_loss(model)
        return (loss, {}) if return_outputs else loss


# RLHFTrainer has no evaluation logic and no eval dataset is supplied, so
# evaluation is disabled. With `evaluation_strategy='steps'` and no
# `eval_dataset`, recent Trainer versions refuse to start.
ppo_trainer = RLHFTrainer(
    model=ppo_opt,
    args=TrainingArguments(
        output_dir="models",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=2,
        gradient_checkpointing=True,
        optim='paged_adamw_8bit',
        logging_steps=500,
        save_strategy='no',
        evaluation_strategy='no',
        num_train_epochs=1,
        learning_rate=5e-5,
        deepspeed={
            "zero_optimization": {
                "stage": 2,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True
                },
                "allgather_partitions": True,
                "allgather_bucket_size": 2e8,
                "reduce_scatter": True,
                "reduce_bucket_size": 2e8,
                "overlap_comm": True,
                "contiguous_gradients": True
            }
        },
    ),
    reward_model=reward_opt,
    reference_model=ref_opt,
    rlhf_config=RLHFConfig(
        memory_size=16,
        sample_size=16,
        max_length=512
    ),
    tokenizer=ppo_tokenizer,
    train_dataset=final_ppo_dataset
)
ppo_trainer.train()

In [None]:
# Persist the PPO-tuned policy.
ppo_trainer.save_model(output_dir=PPO_MODEL_PATH)

## Direct Preference Optimization

In [None]:
# Output directory of the DPO-tuned policy.
DPO_MODEL_PATH = '/'.join([MODEL_DIR, 'dpo_opt'])

In [None]:
# Both models start from the SFT checkpoint: `dpo_opt` is trained, while the
# 8-bit `ref_opt` only provides reference log-probabilities.
dpo_tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)
dpo_opt = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH, device_map='auto')
ref_opt = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_PATH, load_in_8bit=True, device_map='auto'
)

In [None]:
# Fixed seed so the 10k sub-sample is reproducible.
DPO_SPLIT_SEED = 42

# REWARD_DATASET_PATH is a Hub dataset id, so it is loaded by name (loading it
# as a local parquet file fails); `data_dir` selects its `data/rl` subset.
dpo_dataset = load_dataset(
    REWARD_DATASET_PATH,
    data_dir='data/rl',
    split='train'
)
dpo_dataset = dpo_dataset.remove_columns(['qid', 'date', 'metadata'])
dpo_dataset = dpo_dataset.rename_columns(
    {'response_j': 'chosen', 'response_k': 'rejected'}
)
dpo_dataset = dpo_dataset.train_test_split(
    train_size=10000, seed=DPO_SPLIT_SEED
)['train']

instruction_dataset = load_dataset(
    'csv',
    data_files=MANUAL_INSTRUCTION_DATASET,
    split='train'
)
final_dpo_dataset = concatenate_datasets([dpo_dataset, instruction_dataset])
# The DPO collator tokenizes raw text, so only the prompt string is built here.
final_dpo_dataset = final_dpo_dataset.map(
    lambda x: {
        'prompt': f'<Question>: {x["question"]}<Answer>: '
    },
    num_proc=4,
    remove_columns=['question']
)

In [None]:
class DPODataCollatorWithPadding:
    """Tokenize (prompt, chosen, rejected) triples and pad them into DPO batches.

    For each example it produces `prompt_*`, `chosen_*` and `rejected_*`
    entries: `input_ids` and `attention_mask` for all three, plus `labels` for
    the two prompt+response sequences. In `labels`, prompt positions are set to
    `label_padding_value`, so only response tokens are scored. Every response
    ends with an attended EOS token so the policy learns where answers stop.
    """

    def __init__(
        self,
        tokenizer,
        max_length: int = 512,
        max_prompt_length: int = 256,
        padding_value: int = 0,
        label_padding_value: int = - 100, # torch cross entropy ignore index
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_prompt_length = max_prompt_length
        # Padding id for `*_input_ids`; attention masks are always padded with 0.
        self.padding_value = padding_value
        self.label_padding_value = label_padding_value

    def _ignore_eos_attention_mask(self, input_ids, attention_mask):
        """Return a copy of `attention_mask` with every EOS position set to 0."""
        eos_token_positions = {
            i for i, token_id in enumerate(input_ids)
            if token_id == self.tokenizer.eos_token_id
        }
        return [
            0 if i in eos_token_positions else mask
            for i, mask in enumerate(attention_mask)
        ]

    def _truncate_to_length(self, sequence, max_length):
        """Keep the first `max_length` entries of every field."""
        return {k: v[:max_length] for k, v in sequence.items()}

    def _append_eos(self, tokens):
        """Return `tokens` with an attended EOS token appended."""
        return {
            **tokens,
            'input_ids': list(tokens['input_ids']) + [self.tokenizer.eos_token_id],
            'attention_mask': list(tokens['attention_mask']) + [1],
        }

    def _create_sequence_with_labels(self, prompt_tokens, response_tokens):
        """Concatenate prompt and response; mask the prompt part of `labels`."""
        new_tokens = {k: prompt_tokens[k] + response_tokens[k] for k in prompt_tokens.keys()}
        prompt_len = len(prompt_tokens['input_ids'])
        new_tokens['labels'] = new_tokens['input_ids'].copy()
        new_tokens['labels'][: prompt_len] = [self.label_padding_value] * prompt_len
        return new_tokens

    def tokenize_batch_element(
        self,
        prompt: str,
        chosen: str,
        rejected: str,
    ) -> Dict:
        """Tokenize one triple into flat `prompt_*` / `chosen_*` / `rejected_*` lists.

        If prompt + longer response exceeds `max_length`, the prompt is cut to
        `max_prompt_length`; if still too long, the responses are cut to
        `max_length - max_prompt_length`.
        """
        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

        for tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            tokens['attention_mask'] = self._ignore_eos_attention_mask(
                tokens['input_ids'], tokens['attention_mask']
            )

        # Appended after the EOS masking above, so the terminating EOS stays attended.
        chosen_tokens = self._append_eos(chosen_tokens)
        rejected_tokens = self._append_eos(rejected_tokens)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            prompt_tokens = self._truncate_to_length(prompt_tokens, self.max_prompt_length)

        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            max_response_length = self.max_length - self.max_prompt_length
            chosen_tokens = self._truncate_to_length(chosen_tokens, max_response_length)
            rejected_tokens = self._truncate_to_length(rejected_tokens, max_response_length)

        chosen_sequence_tokens = self._create_sequence_with_labels(prompt_tokens, chosen_tokens)
        rejected_sequence_tokens = self._create_sequence_with_labels(prompt_tokens, rejected_tokens)

        batch = {}
        for token_type, tokens in zip(
            ['prompt', 'chosen', 'rejected'],
            [prompt_tokens, chosen_sequence_tokens, rejected_sequence_tokens]
        ):
            for key in ['input_ids', 'attention_mask', 'labels']:
                if key in tokens:
                    batch[f'{token_type}_{key}'] = tokens[key]

        return batch

    def _padding_value_for(self, key):
        """Pad labels with the ignore index, attention masks with 0, ids with `padding_value`."""
        if key.endswith('_labels'):
            return self.label_padding_value
        if key.endswith('_attention_mask'):
            # Padding with a token id (e.g. OPT's pad id 1) would attend to padding.
            return 0
        return self.padding_value

    def collate(self, batch):
        """Right-pad every field of the tokenized examples into a 2-D tensor."""
        return {
            key: nn.utils.rnn.pad_sequence(
                [torch.tensor(item[key]) for item in batch],
                batch_first=True,
                padding_value=self._padding_value_for(key),
            )
            for key in batch[0]
        }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        tokenized_batch = [
            self.tokenize_batch_element(feature["prompt"], feature["chosen"], feature["rejected"])
            for feature in features
        ]
        return self.collate(tokenized_batch)

In [None]:
def pad_batch(tensor, length, padding_value, dim=-1):
    """Right-pad `tensor` along `dim` with `padding_value` up to `length`.

    The padding has the same dtype and device as `tensor`, so boolean or
    integer masks keep their dtype. A tensor that is already at least `length`
    long along `dim` is returned unchanged (it is never truncated).
    """
    missing = length - tensor.shape[dim]
    if missing <= 0:
        return tensor

    padding_shape = list(tensor.shape)
    padding_shape[dim] = missing
    padding_tensor = torch.full(
        padding_shape, padding_value, dtype=tensor.dtype, device=tensor.device
    )
    return torch.cat([tensor, padding_tensor], dim=dim)

class DPOTrainer(Trainer):
    """Direct Preference Optimization on batches from DPODataCollatorWithPadding.

    The policy and a frozen reference model score the chosen and rejected
    sequences; the loss is `-log(sigmoid(beta * (policy log-ratio -
    reference log-ratio)))`. Reward / log-prob metrics are accumulated per
    step and averaged into the next `log` call.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        ref_model: PreTrainedModel,
        data_collator: DPODataCollatorWithPadding,
        beta: float = 0.1,
        disable_dropout: bool = True,
        **regular_trainer_kwargs,
    ):
        if disable_dropout:
            for module in [*model.modules(), *ref_model.modules()]:
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0
        super().__init__(model=model, data_collator=data_collator, **regular_trainer_kwargs)
        self.beta = beta
        self.ref_model = self.accelerator.prepare_model(ref_model, evaluation_mode=True)
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

    def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]):
        """Stack chosen (first) and rejected (second) sequences into one batch.

        Both halves are right-padded to a common length: ids with the
        collator's padding value, labels with its ignore index, and attention
        masks with 0 so that padding is never attended to.
        """
        concatenated_batch = {}
        max_length = max(
            batch["chosen_input_ids"].shape[1],
            batch["rejected_input_ids"].shape[1]
        )

        for key in ['input_ids', 'attention_mask', 'labels']:
            if key == 'labels':
                padding_value = self.data_collator.label_padding_value
            elif key == 'attention_mask':
                padding_value = 0
            else:
                padding_value = self.data_collator.padding_value
            concatenated_batch[f'concatenated_{key}'] = torch.cat(
                [
                    pad_batch(batch[f'chosen_{key}'], max_length, padding_value),
                    pad_batch(batch[f'rejected_{key}'], max_length, padding_value)
                ],
                dim=0
            ).to(self.accelerator.device)

        return concatenated_batch

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        """Per-example DPO losses plus the detached implicit rewards of both answers."""
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        logits = pi_logratios - ref_logratios

        losses = - fn.logsigmoid(self.beta * logits)
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return losses, chosen_rewards, rejected_rewards

    def _get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Sum of the log-probabilities of every non-ignored label, per sequence."""
        # Position t of the logits predicts token t + 1.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = labels != self.data_collator.label_padding_value

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == self.data_collator.label_padding_value] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
        ).squeeze(dim=2)

        return (per_token_logps * loss_mask).sum(dim=-1)

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ):
        """One forward pass over chosen + rejected; returns their log-probs and logits."""
        concatenated_batch = self.concatenated_inputs(batch)
        len_chosen = len(batch["chosen_labels"])

        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
        ).logits.to(torch.float32)

        all_logps = self._get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

    def get_batch_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Mean DPO loss of the batch and a dict of scalar (float) metrics."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        with torch.no_grad():
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
            ) = self.concatenated_forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        # Reduce on the device and transfer only the scalar: the logits are
        # (batch, seq_len, vocab) and copying them to the CPU each step is costly.
        def _scalar_mean(tensor):
            return tensor.detach().mean().item()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = _scalar_mean(chosen_rewards)
        metrics[f"{prefix}rewards/rejected"] = _scalar_mean(rejected_rewards)
        metrics[f"{prefix}rewards/accuracies"] = _scalar_mean(reward_accuracies)
        metrics[f"{prefix}rewards/margins"] = _scalar_mean(chosen_rewards - rejected_rewards)
        metrics[f"{prefix}logps/rejected"] = _scalar_mean(policy_rejected_logps)
        metrics[f"{prefix}logps/chosen"] = _scalar_mean(policy_chosen_logps)
        metrics[f"{prefix}logits/rejected"] = _scalar_mean(policy_rejected_logits)
        metrics[f"{prefix}logits/chosen"] = _scalar_mean(policy_chosen_logits)

        return losses.mean(), metrics

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        **kwargs,
    ):
        """Training loss; `**kwargs` absorbs newer Trainer arguments such as `num_items_in_batch`."""
        loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """Eval loss plus the mean chosen / rejected logits as a 1-D "prediction" tensor."""
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # mean logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }

        # The metrics are plain floats, so build the tensor directly
        # (torch.stack needs tensors).
        logits = torch.tensor(
            [v for k, v in logits_dict.items() if k not in ignore_keys],
            device=self.accelerator.device,
        )
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Buffer per-step metrics until the next `log` call."""
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        """Add the averaged buffered metrics to `logs`, then defer to `Trainer.log`.

        Extra positional / keyword arguments (e.g. `start_time` in newer
        Trainer versions) are forwarded unchanged.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs, *args, **kwargs)

In [None]:
# DeepSpeed ZeRO-2 with the optimizer state offloaded to CPU.
dpo_deepspeed_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "overlap_comm": True,
        "contiguous_gradients": True,
    }
}

# Sequences are capped at 200 tokens, of which at most 50 belong to the prompt.
dpo_collator = DPODataCollatorWithPadding(
    tokenizer=dpo_tokenizer,
    max_length=200,
    max_prompt_length=50,
    padding_value=dpo_tokenizer.pad_token_id,
    label_padding_value=-100,
)

dpo_training_args = TrainingArguments(
    output_dir="models",
    num_train_epochs=1,
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    fp16=True,
    optim='paged_adamw_8bit',
    logging_steps=50,
    evaluation_strategy='no',
    save_strategy='no',
    # The collator needs the raw `prompt` / `chosen` / `rejected` text columns.
    remove_unused_columns=False,
    deepspeed=dpo_deepspeed_config,
)

dpo_trainer = DPOTrainer(
    model=dpo_opt,
    ref_model=ref_opt,
    beta=0.1,
    train_dataset=final_dpo_dataset,
    tokenizer=dpo_tokenizer,
    data_collator=dpo_collator,
    args=dpo_training_args,
)
dpo_trainer.train()

In [None]:
# Persist the DPO-tuned policy.
dpo_trainer.save_model(output_dir=DPO_MODEL_PATH)