# GRPO PyTorch Implementation

Author: *xiaodongguaAIGC*

Github: *dhcode-cpp*

# Define basic functions

## config

In [1]:
import os
import warnings
# Keep notebook output readable: drop Python warnings and lower transformers' log level.
warnings.simplefilter('ignore')
os.environ.update(TRANSFORMERS_VERBOSITY='error')

In [2]:
# Tiny Llama architecture so the notebook stays lightweight.
vocab_size, hidden_size, intermediate_size = 32, 256, 512
num_hidden_layers = 2
num_attention_heads = num_key_value_heads = 4

# Data / sampling settings.
batch_size = 2
length_x = 10
max_new_tokens = 10          # tokens sampled per rollout
grpo_samples_nums = 3        # GRPO group size: rollouts sampled per question

In [3]:
## Model (randomly initialized, not pretrained)

In [4]:
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification

# Fix the RNG before any weights are initialised so runs are reproducible.
torch.manual_seed(1)

# Build a tiny, randomly initialised Llama causal LM from the hyperparameters.
config = LlamaConfig(
    vocab_size=vocab_size,                      # transformers' default is 32000
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
)
model = LlamaForCausalLM(config)

# Use the EOS id as the padding id.
model.config.pad_token_id = model.config.eos_token_id

In [5]:
# Reference policy for the KL term. NOTE: this is an independent random
# initialisation (the RNG has advanced since `model` was built), not a copy of `model`.
model_ref = LlamaForCausalLM(config=config)

## 格式化函数

In [6]:
# Special tag token ids. Question tokens are kept below 25 so they never collide with these.
DEFINE_THINK_START, DEFINE_THINK_END = 25, 26     # <think>, </think>
DEFINE_ANSWER_START, DEFINE_ANSWER_END = 27, 28   # <answer>, </answer>

In [7]:
def format_prompt(question_token_ids):
    """Wrap a question in a one-shot template and open a <think> block.

    The fixed demonstration shows the expected output format: reasoning inside
    <think> ... </think>, then exactly one answer token inside
    <answer> ... </answer>. The returned prompt is

        <think> 2 3 4 </think> <answer> 7 </answer> <question tokens> <think>

    so that sampling continues right after an opening <think> tag (CoT).

    Args:
        question_token_ids: list of question token ids.

    Returns:
        A new list of token ids.
    """
    demonstration = [
        DEFINE_THINK_START, 2, 3, 4, DEFINE_THINK_END,
        DEFINE_ANSWER_START, 7, DEFINE_ANSWER_END,
    ]
    opening_think = [DEFINE_THINK_START]
    return demonstration + question_token_ids + opening_think

In [8]:
# Toy questions: token ids are restricted to < 25 so they never clash with the special tags.
X = [
    [1, 5, 8, 3, 4, 18, 10, 12, 20, 11],
    [6, 9, 1, 7, 4, 21, 10, 15, 4, 23],
]
# One reference answer token per question.
Y = [[3], [7]]

In [9]:
# Show what the first question looks like after templating.
test_format = format_prompt(question_token_ids=X[0])
print(test_format)

### GRPO rejection sampling

In [10]:
def grpo_rejection_sampling(model, x, max_new_tokens = 10,):
    """Sample one stochastic continuation for every row of `x`.

    Args:
        model: a causal LM exposing a HuggingFace-style ``generate`` method.
        x: LongTensor of prompt token ids, shape (num_prompts, prompt_len).
            Rows are assumed to be unpadded (every prompt has the same length).
        max_new_tokens: maximum number of tokens to sample per row.

    Returns:
        Whatever ``model.generate`` returns (for a decoder-only model: the
        prompts followed by their sampled continuations).
    """
    # Prompts are never padded here, so every position must be attended.
    # Passing the mask explicitly keeps `generate` from inferring one from the
    # pad id, which can hide prompt tokens that happen to equal that id (the
    # prompt template itself contains token id 2).
    attention_mask = torch.ones_like(x)
    return model.generate(
        input_ids=x,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

# 1) Raw question tokens.
print(X[0])

# 2) Wrap the question in the prompt template.
input_x = format_prompt(X[0])
print(input_x)

# 3) One GRPO group: the same prompt repeated once per rollout.
input_x = [input_x for _ in range(grpo_samples_nums)]
print(input_x)

# 4) As a LongTensor of shape (grpo_samples_nums, prompt_len).
input_x_tensor = torch.tensor(input_x, dtype=torch.long)
print(input_x_tensor.shape)

# 5) Sample one continuation per row: prompt + completion.
grpo_xy = grpo_rejection_sampling(model, input_x_tensor, max_new_tokens=max_new_tokens)
print(grpo_xy)

### reward function

In [11]:
def rule_reward(response, label_id):
    """Return True if `response` contains the exact triple <answer>, label_id, </answer>.

    Rule-based correctness reward: the answer must be a single token wrapped in
    answer tags. Windows of three consecutive tokens are checked left to right.
    """
    n_windows = len(response) - 2
    return any(
        response[i] == DEFINE_ANSWER_START
        and response[i + 1] == label_id
        and response[i + 2] == DEFINE_ANSWER_END
        for i in range(n_windows)
    )
    
def think_reward(response):
    """Return True if a </think> tag appears anywhere after a <think> tag.

    A single iterator is shared by both membership tests. The first test
    consumes tokens up to and including the first <think>. The second test
    therefore only searches the tokens that follow it.
    """
    tokens = iter(response)
    return DEFINE_THINK_START in tokens and DEFINE_THINK_END in tokens

# --- rule_reward: full sequence (prompt + completion) vs. completion only ---
result = rule_reward(grpo_xy[0].tolist(), 4)
print(result)

# Always True on the full sequence: the prompt's one-shot demo already
# contains "<answer> 7 </answer>", so rewards must ignore the prompt.
result = rule_reward(grpo_xy[0].tolist(), 7)
print(result)

len_input_x = len(input_x[0])            # formatted prompt length
# Start one token early so the opening <think> that ends the prompt is kept.
result = rule_reward(grpo_xy[0, len_input_x - 1:].tolist(), 7)
print(result)

print(grpo_xy[0, :])                     # prompt + completion
print(grpo_xy[0, len_input_x:])          # completion only

# --- think_reward: a <think> ... </think> pair must appear ---
# Also trivially True on the full sequence: the demo has <think> ... </think>.
result = think_reward(grpo_xy[0].tolist())
print(result)

result = think_reward(grpo_xy[0, len_input_x - 1:].tolist())
print(result)

print(grpo_xy[0, :])
print(grpo_xy[0, len_input_x:])

# GRPO 训练流程

1. 批量采样
2. 计算奖励
3. 计算loss

## 批量采样

一个question 对应 多条回答

In [12]:
def GRPO_batch_rejection_sample(inputs, nums, max_new_tokens = 10, policy_model=None):
    """Sample a GRPO group of `nums` rollouts for every question in `inputs`.

    Each question is formatted with `format_prompt`, repeated `nums` times and
    passed to `grpo_rejection_sampling`. Groups are kept as separate tensors,
    so questions with different prompt lengths never need padding.

    Args:
        inputs: iterable of questions, each a list of token ids.
        nums: group size (rollouts per question).
        max_new_tokens: sampling budget per rollout.
        policy_model: model to sample from. Defaults to the notebook-global
            `model`, looked up at call time (the original behaviour).

    Returns:
        (grpo_xy_batch, grpo_x_len):
        - grpo_xy_batch: a list with one tensor per question, holding prompt
          plus completion tokens for its `nums` rollouts.
        - grpo_x_len: the formatted prompt length of each question.
    """
    # Batching several questions together would need left-padded prompts
    # (p = pad, x = prompt, y = response, m = masked-out position):
    #   1:  p p p p p p x x x | y y y y y p p p
    #   1:  m m m m m m m m m |           m m m
    #   2:  x x x x x x x x x | y y y y y y y p
    #   2:  m m m m m m m m m |               m
    # Sampling each question as its own group avoids the padding entirely.
    sampler = model if policy_model is None else policy_model

    grpo_xy_batch, grpo_x_len = [], []
    for question in inputs:
        prompt = format_prompt(question)
        grpo_x_len.append(len(prompt))
        group_prompts = torch.tensor([prompt] * nums, dtype=torch.long)  # (nums, prompt_len)
        grpo_xy_batch.append(grpo_rejection_sampling(sampler, group_prompts, max_new_tokens))
    return grpo_xy_batch, grpo_x_len


grpo_xy_batch, grpo_x_len = GRPO_batch_rejection_sample(
    X,
    grpo_samples_nums,
    max_new_tokens=max_new_tokens,
)

print(grpo_xy_batch)
# One (grpo_samples_nums, total_len) tensor per question.
print(grpo_xy_batch[0].shape)

## 批量计算奖励

In [13]:
def GRPO_batch_reward(X, inputs_responses, labels):
    """Rule-based rewards for every rollout of every question.

    Args:
        X: raw questions (lists of token ids, before `format_prompt`).
        inputs_responses: one tensor per question of shape
            (group_size, prompt_len + generated_len): the formatted prompt
            followed by the sampled completion.
        labels: per-question labels; ``label[0]`` is the answer token id.

    Returns:
        A list with one entry per question, each a list of bools (one per rollout).
    """
    batch_rewards = []
    for question, group_xy, label in zip(X, inputs_responses, labels):
        # Score only the generated tokens. The formatted prompt contains the
        # one-shot demo "<answer> 7 </answer>", so scoring the full sequence
        # would reward every rollout of a question whose label is 7.
        prompt_len = len(format_prompt(list(question)))
        responses = group_xy[:, prompt_len:].tolist()
        label_id = label[0]
        batch_rewards.append([rule_reward(response, label_id) for response in responses])
    return batch_rewards

print(X, Y)
# Per-question lists of booleans, one entry per sampled rollout.
batch_rewards = GRPO_batch_reward(X, inputs_responses=grpo_xy_batch, labels=Y)
print(batch_rewards)

## 批量GRPO

1. advantage
2. KL
3. GRPO

### GRPO Advantage

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$  

In [14]:
def grpo_advantage(rewards, epsilon=0.0001):
    """Group-relative advantage: standardise rewards within one GRPO group.

    A_i = (r_i - mean(r)) / (std(r) + epsilon), using the unbiased (n - 1) std
    (``torch.Tensor.std``'s default). Every token of rollout i shares A_i.

    Args:
        rewards: per-rollout rewards for one group (bools or numbers).
        epsilon: added to the std so an all-equal group gives 0 instead of 0/0.

    Returns:
        1-D float tensor of advantages, the same length as `rewards`. Groups
        with fewer than two rollouts carry no relative signal and get all-zero
        advantages (the unbiased std of a single value is NaN).
    """
    rewards = torch.tensor(rewards, dtype=torch.float)
    if rewards.numel() < 2:
        return torch.zeros_like(rewards)
    return (rewards - rewards.mean()) / (rewards.std() + epsilon)

# Advantages for the rollouts of the first question.
advantage = grpo_advantage(rewards=batch_rewards[0])
print(advantage)

1. advantage描述相对性估计，全对全错，优化可以skip，这种情况过多，将难以进行RL优化
2. case2里，越少正例，advantage越大
3. advantage有正负
4. reward为{0,1}之中的一个值, advantage为浮点数
5. 有越多的采样，估计的advantage越准确

In [15]:
# Edge cases: all wrong, one right, two right, three right, all right.
for group_rewards in (
    [0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 1, 0],
    [1, 0, 0, 1, 1, 0],
    [1, 1, 1, 1, 1, 1],
):
    A = grpo_advantage(group_rewards)
    print(A)

In [16]:
# 64 rollouts with a single success: the lone correct rollout gets a large positive advantage.
reward_batch = [1] + [0] * 63
A_64 = grpo_advantage(reward_batch)
print(A_64)

### GRPO KL

ref: [Schulman (2020), Approximating KL Divergence](http://joschu.net/blog/kl-approx.html)

$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
$$

具体分析可以看`./notebook/grpo/GRPO_KL.ipynb`

In [17]:
def grpo_kl(pi_logprob, pi_ref_logprob):
    """Per-element k3 estimator of KL(pi_theta || pi_ref) (Schulman, 2020).

    With r = pi_ref / pi_theta at the sampled token, KL ~= r - log(r) - 1,
    which is non-negative for every r > 0.

    Args:
        pi_logprob: log-probabilities under the current policy.
        pi_ref_logprob: log-probabilities under the reference policy (same shape).

    Returns:
        Tensor of per-element KL estimates with the inputs' shape.
    """
    # Build the ratio from the log-difference instead of exp(a) / exp(b): for
    # very negative log-probs both exponentials underflow to 0 and give NaN.
    log_ratio = pi_ref_logprob - pi_logprob
    return torch.exp(log_ratio) - log_ratio - 1

# Sanity check on random distributions: every entry of the estimate is >= 0.
pi = torch.randn(3, 5)
pi_ref = torch.randn(3, 5)
pi_logprob, pi_ref_logprob = (
    torch.nn.functional.log_softmax(scores, dim=1) for scores in (pi, pi_ref)
)
print(grpo_kl(pi_logprob, pi_ref_logprob))

### GRPO Loss

ref: [TRL:GRPO](https://huggingface.co/docs/trl/main/en/grpo_trainer#trl.GRPOTrainer)

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where  \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the old (sampling) policy \\(\pi_{\theta_{\text{old}}}\\) by bounding the policy ratio between  \\( 1 - \epsilon \\) and  \\( 1 + \epsilon \\).
In TRL, as in the original paper, only one update is done per generation, so the loss simplifies to the first form. This notebook instead takes several optimizer steps on each batch of generations (`grpo_epochs`), so it implements the clipped second form.

#### How to get the log-probability of each sampled token ("action")?

In [18]:
import torch

# Logits or log-probs laid out as (batch_size, sequence_length, vocab_size).
x = torch.randn(2, 3, 5)

# Chosen token id at every position: (batch_size, sequence_length).
# It has one dim fewer than `x` because we select along the last (vocab) dim.
indices = torch.tensor([[0, 2, 1],    # batch 0
                        [4, 3, 4]])   # batch 1

# output[b, t] = x[b, t, indices[b, t]]; dim=-1 is the same as dim=2 here.
output = x.gather(-1, indices[..., None])[..., 0]

print(f"输入形状: {x.shape}")        # torch.Size([2, 3, 5])
print(f"索引形状: {indices.shape}")   # torch.Size([2, 3])
print(f"输出形状: {output.shape}")    # torch.Size([2, 3])


#### GRPO loss from logits

In [19]:

def grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, advantage, input_len,
              *, epsilon=0.2, beta=0.01, response_mask=None):
    """Clipped GRPO surrogate loss for one group of G rollouts.

        L = -1/G * sum_i 1/|o_i| * sum_t [ min(r_it * A_i, clip(r_it, 1-eps, 1+eps) * A_i)
                                           - beta * KL_it ]

    Here r_it = pi_theta / pi_theta_old at the sampled token, and KL_it comes
    from `grpo_kl`.

    Args:
        pi_logprob: (G, T) sampled-token log-probs under the current policy.
        pi_old_logprob: (G, T) the same under the sampling-time policy.
        pi_ref_logprob: (G, T) the same under the reference policy.
        advantage: (G,) one advantage per rollout, broadcast over its tokens.
        input_len: positions [0, input_len) are prompt and do not count.
        epsilon: PPO clip range.
        beta: KL penalty coefficient.
        response_mask: optional (G, T) 0/1 mask that further restricts which
            positions count (e.g. to drop padding after EOS). By default every
            position from `input_len` on counts.

    Returns:
        Scalar loss tensor.
    """
    group_num, seq_len = pi_logprob.shape

    # Only response positions contribute to the loss.
    mask = pi_logprob.new_zeros(group_num, seq_len)
    mask[:, input_len:] = 1
    if response_mask is not None:
        mask = mask * response_mask.to(mask.dtype)

    # |o_i|: number of counted tokens per rollout. Clamp so that an empty
    # response (input_len >= seq_len) gives a zero loss instead of 0/0 = NaN.
    len_oi = mask.sum(dim=1).clamp(min=1)

    advantage = advantage.unsqueeze(dim=1)  # (G,) -> (G, 1): broadcast over tokens
    ratio = torch.exp(pi_logprob - pi_old_logprob)
    ratio_clip = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    policy_gradient = torch.minimum(ratio * advantage, ratio_clip * advantage)
    kl = grpo_kl(pi_logprob, pi_ref_logprob)

    per_token = (policy_gradient - beta * kl) * mask
    return -(per_token / len_oi.unsqueeze(dim=1)).sum() / group_num
    

# Toy check: a group of 3 rollouts of length 5 over a 32-token vocabulary.
pi_logits = torch.randn(3, 5, 32)
pi_ref_logits = torch.randn(3, 5, 32)
pi_old_logits = torch.randn(3, 5, 32)

# The first 3 tokens (11, 12, 13) are the prompt; the last 2 are the response.
token_ids = torch.tensor([
    [11, 12, 13, 14, 15],
    [11, 12, 13, 15, 16],
    [11, 12, 13, 16, 17],
])

# Log-prob of each chosen token under the three policies: (3, 5) each.
pi_logprob, pi_ref_logprob, pi_old_logprob = (
    torch.nn.functional.log_softmax(logits, dim=-1)
    .gather(-1, token_ids.unsqueeze(-1))
    .squeeze(-1)
    for logits in (pi_logits, pi_ref_logits, pi_old_logits)
)

# `advantage` is the 3-rollout group advantage computed earlier.
loss = grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, advantage, 3)
print(loss)

# 完整算法

## GRPO input & label

In [20]:
# Four toy questions (10 token ids each, all < 25) and their one-token answers.
X = torch.tensor(
    [
        [1, 5, 8, 3, 4, 18, 10, 12, 20, 11],
        [6, 9, 13, 7, 4, 21, 10, 15, 4, 23],
        [3, 5, 14, 6, 10, 20, 4, 9, 10, 15],
        [6, 19, 17, 5, 16, 21, 10, 20, 13, 19],
    ],
    dtype=torch.long,
)
Y = torch.tensor([[3], [7], [12], [4]], dtype=torch.long)

## GRPO Dataset

In [21]:
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Map-style dataset pairing each question with its label.

    Each item is a dict with keys ``'input'`` and ``'label'``; the training loop
    reads ``batch['input']`` / ``batch['label']`` from the DataLoader.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return dict(input=self.data[idx], label=self.labels[idx])

dataset = MyDataset(data=X, labels=Y)
# Two questions per step, reshuffled every epoch.
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

## GRPO on-policy sampling step and optimizer step

In [22]:
import torch.optim as optim


def _sampled_token_logprobs(logits, token_ids):
    """Log-probability of each actually-sampled next token.

    Causal-LM logits at position t score token t + 1, so the last position of
    `logits` and the first token of `token_ids` are dropped. The result has
    shape (batch, seq_len - 1); entry j is log p(token_ids[:, j + 1] | prefix).
    """
    logprobs = torch.nn.functional.log_softmax(logits[:, :-1], dim=-1)
    return torch.gather(logprobs, dim=-1, index=token_ids[:, 1:].unsqueeze(-1)).squeeze(-1)


model = LlamaForCausalLM(config)
# The reference policy must start as a frozen copy of the policy being trained.
# `model_ref` was created earlier from an independent random initialisation.
model_ref.load_state_dict(model.state_dict())
model_ref.requires_grad_(False)
model_ref.eval()
optimizer = optim.Adam(model.parameters(), lr=0.000001)

epochs = 10
grpo_epochs = 10
for i in range(epochs):
    for batch in data_loader:
        # STEP 1: on-policy sampling with the current policy.
        batch_inputs = batch['input'].tolist()
        batch_labels = batch['label']
        grpo_xy_batch, grpo_x_len = GRPO_batch_rejection_sample(
            batch_inputs, grpo_samples_nums, max_new_tokens=max_new_tokens)

        # STEP 2: rule-based rewards and group-relative advantages.
        batch_rewards = GRPO_batch_reward(batch_inputs, grpo_xy_batch, batch_labels)
        batch_advantage = [grpo_advantage(group_rewards) for group_rewards in batch_rewards]

        # STEP 3: freeze sampled-token log-probs of the sampling-time ("old")
        # policy and of the reference policy.
        pi_old_logprob_list = []
        pi_ref_logprob_list = []
        with torch.no_grad():
            for grpo_xy in grpo_xy_batch:
                pi_old_logprob_list.append(_sampled_token_logprobs(model(grpo_xy).logits, grpo_xy))
                pi_ref_logprob_list.append(_sampled_token_logprobs(model_ref(grpo_xy).logits, grpo_xy))

        # STEP 4: several optimisation steps on the same rollouts
        # (mini-batching could be added here, as in PPO).
        for k in range(grpo_epochs):
            total_loss = 0
            for pi_old_logprob, pi_ref_logprob, advantage, x_len, grpo_xy in zip(
                    pi_old_logprob_list,
                    pi_ref_logprob_list,
                    batch_advantage,
                    grpo_x_len,
                    grpo_xy_batch):
                pi_grpo_logprob = _sampled_token_logprobs(model(grpo_xy).logits, grpo_xy)
                # After the one-step shift, the first response token (sequence
                # position x_len) sits at index x_len - 1.
                loss = grpo_loss(pi_grpo_logprob, pi_old_logprob, pi_ref_logprob,
                                 advantage, x_len - 1)
                total_loss += loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
    print(i, total_loss)

Thanks to Julian Lou for help with debugging.