## CS310 Natural Language Processing
## Lab 13: Human Alignment

In this lab, we will practice two tasks:
- Using the code framework for training a reward model that assigns scores to pairs of sentences. 
- Getting familiar with the code framework for Direct Preference Optimization (DPO).


In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import LlamaForCausalLM

## T1. Defining Reward Model


We will use the [LlamaForCausalLM](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaForCausalLM) model from HuggingFace as the basis for our reward model.

First, implement two internal forward functions:
- `_forward_rm`: takes the input ids and attention mask of a sequence (user input + response) and returns the reward scores.
  - The reward scores form a tensor with **one reward score for each token**, i.e. shape `[batch_size, seq_len, 1]` before the last dimension is squeezed.
  - Reward scores are calculated by applying the linear layer `self.reward_head` to the last hidden state (of the entire sequence).
- `_forward_lmloss`: takes input in the same format, but returns the regular language modeling loss.
  - Logits are computed by applying `self.lm_head` to the last hidden state.
  - The `response_ids` are used as the targets for `nn.CrossEntropyLoss()`.
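To make the shapes concrete, here is a minimal sketch of both computations in plain Python, without PyTorch, so the arithmetic stays visible. The helper names `token_rewards` and `cross_entropy` and the toy numbers are illustrative assumptions, not part of the lab framework; in the lab these steps are done by `self.reward_head`, `self.lm_head`, and `nn.CrossEntropyLoss()`.

```python
import math

def token_rewards(hidden_states, head_weight):
    """Bias-free linear head: maps each token's hidden vector to one scalar.

    hidden_states: list of seq_len vectors, each of length hidden_size
    head_weight:   list of hidden_size floats (hypothetical stand-in for reward_head.weight)
    """
    return [sum(h * w for h, w in zip(vec, head_weight)) for vec in hidden_states]

def cross_entropy(logits, targets):
    """Mean cross-entropy over tokens, computed from raw logits.

    logits:  list of seq_len vectors, each of length vocab_size
    targets: list of seq_len target token ids
    """
    total = 0.0
    for row, target in zip(logits, targets):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
    return total / len(targets)

# Toy sizes: seq_len = 3, hidden_size = 2, vocab_size = 4
hidden = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
rewards = token_rewards(hidden, head_weight=[2.0, -1.0])
print(rewards)  # one score per token

uniform_logits = [[0.0, 0.0, 0.0, 0.0]] * 3
lm_loss = cross_entropy(uniform_logits, targets=[1, 3, 0])
print(lm_loss)  # uniform predictions give log(vocab_size)
```

The second function mirrors what happens after the logits are flattened to `[batch_size * seq_len, vocab_size]` and the targets to `[batch_size * seq_len]`: the loss is averaged over all token positions.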

Then, define the `forward` function, which takes the input ids and attention masks of two sequences and returns the combined loss.
- Compute `reward0` on the first sequence (positive example) and `reward1` on the second sequence (negative example).
- Store their difference in `logits`.
- The reward loss is computed by calling `F.binary_cross_entropy_with_logits(logits, labels)`.
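The reward loss above can be illustrated on a single pair of scalar rewards. This is a sketch in plain Python under the assumption that binary cross-entropy on a logit `z` with target `y` is `-[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))]`; the function names and reward values are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit z with target y in [0, 1].

    Uses the numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|)).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

reward_pos, reward_neg = 1.5, -0.5   # hypothetical scalar rewards
logit = reward_pos - reward_neg      # reward difference

loss_label1 = bce_with_logits(logit, 1.0)  # equals -log(sigmoid(reward_pos - reward_neg))
loss_label0 = bce_with_logits(logit, 0.0)  # equals -log(sigmoid(reward_neg - reward_pos))
print(loss_label1, loss_label0)
```

Note that the label decides which direction the loss rewards: with label 1 the loss shrinks as the first reward exceeds the second, and with label 0 it shrinks in the opposite case. It is worth checking that the labels passed to `forward` match the convention you intend.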

In [83]:
class LlamaRewardModel(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        # A linear layer to map hidden states to a scalar, as the final reward
        self.reward_head = nn.Linear(config.hidden_size, 1, bias=False)

    def _forward_rm(self, input_ids, attention_mask, **kargs):
        """
        input_ids: input token ids
        attention_mask: attention mask
        Return: reward scores, output from self.reward_head
        """
        # Call self.model.forward()  to get the hidden states
        output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask, 
            return_dict=True,
            use_cache=False
        )
        ### START YOUR CODE ###
        # Feed the last hidden state from output to self.reward_head to get the reward score
        last_hidden_state = output.last_hidden_state
        rewards = self.reward_head(last_hidden_state)
        ### END YOUR CODE ###

        return rewards 
    
    def _forward_lmloss(self, prompt_ids, lm_attn_mask, response_ids):
        """
        prompt_ids: input token ids
        lm_attn_mask: attention mask for prompt_ids
        response_ids: target token ids for the cross-entropy loss
        Return: cross-entropy loss for language modeling
        """ 
        # Call self.model.forward()  to get the hidden states
        outputs = self.model.forward(
            input_ids=prompt_ids,
            attention_mask=lm_attn_mask,
            return_dict=True,
            use_cache=False,
        )

        ### START YOUR CODE ###
        # Get the hidden states of the last layer
        last_hidden_state = outputs.last_hidden_state
        # Project them through lm_head to get the logits
        logits = self.lm_head(last_hidden_state)
        
        # Compute the cross-entropy loss
        criterion = nn.CrossEntropyLoss()
        # Flatten logits to [batch_size * seq_len, vocab_size] and targets to [batch_size * seq_len]
        logits = logits.view(-1, logits.size(-1))
        response_ids = response_ids.view(-1)
        loss = criterion(logits, response_ids)
        ### END YOUR CODE ###

        return loss
        
    def forward(self, sent1_idx, attention_mask_1, sent2_idx, attention_mask_2, labels, prompt_ids, lm_attn_mask, response_ids, **kargs):
        """
        sent1_idx: User input ids + positive output ids
        attention_mask_1: Attention mask for sent1_idx
        sent2_idx: User input ids + negative output ids
        attention_mask_2: Attention mask for sent2_idx

        labels: Targets for the reward loss, same shape as sent1_idx (all zeros)

        prompt_ids: User input ids + positive output ids
        lm_attn_mask: Attention mask for prompt_ids
        response_ids: Target ids for calculating cross-entropy loss
        """

        ### START YOUR CODE ###
        # Reward for positive example
        reward0 = self._forward_rm(sent1_idx, attention_mask_1)
        # Reward for negative example
        reward1 = self._forward_rm(sent2_idx, attention_mask_2)
        # Calculate the reward difference
        logits = reward0 - reward1  # Shape: [batch_size, seq_len, 1]
        
        # Squeeze the last dimension to match labels shape
        logits = logits.squeeze(-1)  # Shape: [batch_size, seq_len]
        ### END YOUR CODE ###

        # Compute the reward modeling loss
        rm_loss = F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype), reduction="mean")

        # Compute the language modeling loss 
        lm_loss = self._forward_lmloss(prompt_ids, lm_attn_mask, response_ids)

        # Final loss
        loss = rm_loss + lm_loss

        return loss

In [84]:
# Test
#model = LlamaRewardModel.from_pretrained('/Users/xy/models/llama-2-7b-hf')
model = LlamaRewardModel.from_pretrained('./qwen')
# model = LlamaRewardModel.from_pretrained(
#     "Qwen/Qwen-7B", 
#     revision="main",  
#     trust_remote_code=True
# )
# You should see the model initialized correctly

You are using a model of type qwen2 to instantiate a model of type llama. This is not supported for all configurations of models and can yield errors.
Some weights of LlamaRewardModel were not initialized from the model checkpoint at ./qwen and are newly initialized: ['reward_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## T2. Load Preference Data

We will load the preference dataset from `Anthropic/hh-rlhf` for testing.

In [85]:
from dataclasses import dataclass
from datasets import load_dataset
from transformers import TrainingArguments, AutoTokenizer
from transformers.hf_argparser import HfArg

In [86]:
@dataclass
class Arguments(TrainingArguments):
    model_name_or_path: str = HfArg(
        default="./qwen", # The path to your model
        help="The model name or path"
    )
    
    # Preference dataset
    data_path: str = HfArg(
        default='./hh-rlhf', # The path to the preference dataset
        help="The path of preference dataset, e.g., `Anthropic/hh-rlhf`",
    )

    model_max_length: int = HfArg(default=512, help="Maximum sequence length.")

    bf16: bool = HfArg(
        default=True,
        help="Whether to use bf16 (mixed) precision instead of 32-bit.",
    )

    # Hyper-parameters for DPO loss
    beta: float = HfArg(
        default=0.1,
        help="The beta factor in DPO loss. "
        "Higher beta means less divergence from the initial policy.",
    )

    output_dir: str = HfArg(
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )

In [87]:
# Test
args = Arguments()
print(args.model_name_or_path)
print(args.data_path)

./qwen
./hh-rlhf


The following function prepares the preference dataset in a user-friendly view.

In [88]:
def get_data(split, data_path):
    dataset = load_dataset(split=split, path=data_path)

    def split_prompt_and_responses_hh(sample):
        search_term = "\n\nAssistant:"
        search_term_idx = sample["chosen"].rfind(search_term)
        assert search_term_idx != -1, f"Prompt and response do not contain '{search_term}'"
        prompt = sample["chosen"][:search_term_idx + len(search_term)]
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt):],
            "rejected": sample["rejected"][len(prompt):],
        }

    return dataset.map(split_prompt_and_responses_hh)
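To see what the splitting step does without downloading the dataset, the sketch below applies the same logic to one hand-written sample. The sample text and the standalone function name are made up for illustration. The logic assumes, as `get_data` does, that the chosen and rejected texts share the same prefix up to the last `"\n\nAssistant:"` marker.

```python
def split_prompt_and_responses(sample, search_term="\n\nAssistant:"):
    """Split a preference sample into a shared prompt and two responses."""
    idx = sample["chosen"].rfind(search_term)  # last assistant turn
    if idx == -1:
        raise ValueError(f"Sample does not contain {search_term!r}")
    prompt = sample["chosen"][: idx + len(search_term)]
    return {
        "prompt": prompt,
        "chosen": sample["chosen"][len(prompt):],
        "rejected": sample["rejected"][len(prompt):],
    }

# Hypothetical sample in the dialogue format used by the dataset
sample = {
    "chosen": "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a color.\n\nAssistant: Blue.",
    "rejected": "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a color.\n\nAssistant: No.",
}
parts = split_prompt_and_responses(sample)
print(repr(parts["prompt"][-12:]))
print(repr(parts["chosen"]), repr(parts["rejected"]))
```

Using `rfind` rather than `find` means that earlier assistant turns stay inside the prompt, so only the final response differs between the two examples.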

In [89]:
# Test
train_dataset = get_data('train', args.data_path)

prompt = train_dataset[0]['prompt']
chosen = train_dataset[0]['chosen']
rejected = train_dataset[0]['rejected']
print(prompt[:45])
print(chosen)
print(rejected)



Human: What are some cuss words in english?
 I haven't even thought about it.
 Ass.


Now, load tokenizer and tokenize some sample data.

- `sent1_encoded` is the tokenized result of `prompt + chosen` (positive example)
- `sent2_encoded` is the tokenized result of `prompt + rejected` (negative example)

In [90]:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)


### START YOUR CODE ###
sent1_encoded = tokenizer(
    prompt + chosen,
    truncation=True,
    max_length=args.model_max_length,
    return_tensors="pt"
)

sent2_encoded = tokenizer(
    prompt + rejected,
    truncation=True,
    max_length=args.model_max_length,
    return_tensors="pt"
)

### END YOUR CODE ###

Pad the two sequences (input ids and attention masks) to the same length.

In [91]:
sent1_idx = sent1_encoded['input_ids']
sent2_idx = sent2_encoded['input_ids']

# Pad input ids
max_len = max(sent1_idx.shape[1], sent2_idx.shape[1])
sent1_idx = torch.nn.functional.pad(sent1_idx, (0, max_len - sent1_idx.shape[1]), value=tokenizer.pad_token_id)
sent2_idx = torch.nn.functional.pad(sent2_idx, (0, max_len - sent2_idx.shape[1]), value=tokenizer.pad_token_id)

# Pad attention masks
sent1_attn_mask = sent1_encoded['attention_mask']
sent2_attn_mask = sent2_encoded['attention_mask']
sent1_attn_mask = torch.nn.functional.pad(sent1_attn_mask, (0, max_len - sent1_attn_mask.shape[1]), value=0)
sent2_attn_mask = torch.nn.functional.pad(sent2_attn_mask, (0, max_len - sent2_attn_mask.shape[1]), value=0)

print(sent1_idx.shape)
print(sent2_idx.shape)
print(sent1_attn_mask.shape)
print(sent2_attn_mask.shape)

torch.Size([1, 185])
torch.Size([1, 185])
torch.Size([1, 185])
torch.Size([1, 185])
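The padding logic can also be checked on small Python lists. This sketch assumes right padding with a hypothetical `pad_id` of 0; the helper `pad_pair` is illustrative and not part of the lab code.

```python
def pad_pair(ids_a, ids_b, pad_id):
    """Right-pad two id lists to a common length and build their attention masks."""
    max_len = max(len(ids_a), len(ids_b))

    def pad(ids):
        n_pad = max_len - len(ids)
        padded_ids = ids + [pad_id] * n_pad
        mask = [1] * len(ids) + [0] * n_pad  # 1 = real token, 0 = padding
        return padded_ids, mask

    return pad(ids_a), pad(ids_b)

(ids_1, mask_1), (ids_2, mask_2) = pad_pair([5, 6, 7, 8], [5, 6, 9], pad_id=0)
print(ids_1, mask_1)
print(ids_2, mask_2)
```

Padding the attention mask with zeros is what tells the model to ignore the padded positions, which is why the mask is padded with `value=0` rather than with the pad token id.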


Prepare input data

In [92]:
input_data = {
    'sent1_idx': sent1_idx, 
    'attention_mask_1': sent1_attn_mask, 
    'sent2_idx': sent2_idx, 
    'attention_mask_2': sent2_attn_mask, 

    'labels': torch.zeros_like(sent1_idx), 

    'prompt_ids': sent1_encoded['input_ids'], 
    'lm_attn_mask': sent1_encoded['attention_mask'], 
    'response_ids': sent1_encoded['input_ids'],
}

In [93]:
with torch.no_grad():
    output = model(**input_data)
    print(output)

# You should see a single loss value
# A RuntimeError is most likely caused by the implementation of the internal forward functions
# You can use the following code to help you debug
# r1 = model._forward_rm(sent1_idx, sent1_attn_mask)
# print(r1.shape)

tensor(17.3644)


## T3. (Optional) DPO Training

You need to install the [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/en/index) library first.

```bash
pip install trl
```
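For orientation, the per-pair DPO objective can be sketched in plain Python. This follows the commonly stated form of the loss, `-log(sigmoid(beta * ((log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l)))))`. The function name and the log-probability values are hypothetical, and TRL's actual implementation offers additional options not shown here.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for one preference pair, given sequence log-probabilities."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) written as softplus(-margin) for numerical stability
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# Policy identical to the reference: the margin is 0, so the loss is log(2)
loss_at_init = dpo_loss(-12.0, -15.0, -12.0, -15.0)

# Policy raises the chosen response and lowers the rejected one: the loss drops
loss_improved = dpo_loss(-10.0, -17.0, -12.0, -15.0)
print(loss_at_init, loss_improved)
```

Because only log-ratios against the frozen reference model enter the loss, no separate reward model is needed. `beta` scales how strongly deviations from the reference are weighted.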

In [94]:
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, HfArgumentParser

In [95]:
def train():
    # Parse arguments
    parser = HfArgumentParser(Arguments)
    args = parser.parse_args_into_dataclasses()[0]
    
    # Load policy model
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    # Load reference model
    model_ref = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    # Freeze reference model
    model_ref.eval()
    for param in model_ref.parameters():
        param.requires_grad = False

    # Tokenizer and data
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.model_max_length,
        padding_side="right",
        add_eos_token=True,
    )
    train_dataset = get_data("train", args.data_path)

    # Training arguments
    kwargs = dict(
        model=model,
        ref_model=model_ref,
        args=args,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
    )

    # Start training
    dpo_trainer = DPOTrainer(**kwargs)
    dpo_trainer.train()
    dpo_trainer.save_state()

In [96]:
train()

usage: ipykernel_launcher.py [-h] [--output_dir OUTPUT_DIR]
                             [--overwrite_output_dir [OVERWRITE_OUTPUT_DIR]]
                             [--do_train [DO_TRAIN]] [--do_eval [DO_EVAL]]
                             [--do_predict [DO_PREDICT]]
                             [--eval_strategy {no,steps,epoch}]
                             [--prediction_loss_only [PREDICTION_LOSS_ONLY]]
                             [--per_device_train_batch_size PER_DEVICE_TRAIN_BATCH_SIZE]
                             [--per_device_eval_batch_size PER_DEVICE_EVAL_BATCH_SIZE]
                             [--per_gpu_train_batch_size PER_GPU_TRAIN_BATCH_SIZE]
                             [--per_gpu_eval_batch_size PER_GPU_EVAL_BATCH_SIZE]
                             [--gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS]
                             [--eval_accumulation_steps EVAL_ACCUMULATION_STEPS]
                             [--eval_delay EVAL_DELAY]
                         

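SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

The usage message and `SystemExit: 2` come from the argument parser, not from training. `parse_args_into_dataclasses()` reads `sys.argv`, which in a notebook holds the kernel launcher's own arguments, and this is most likely what triggers the exit. Run `train()` from a standalone script, or pass `args=[]` to `parse_args_into_dataclasses` to fall back to the defaults defined in `Arguments`.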
