In [1]:
from typing import Optional, Tuple, Union

import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
import torch
from torch import nn
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

## 奖励模型

In [2]:
reward_tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/reward-model-deberta-v3-base")
reward_model = AutoModelForSequenceClassification.from_pretrained("OpenAssistant/reward-model-deberta-v3-base")

device = torch.device('mps')
reward_model = reward_model.to(device)
reward_model.eval()

DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): Dropout(p=0.1, inplace=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): Layer

In [3]:
prompt = "Please tell me how to learn machine learning."
response1 = "Learning machine learning can start with mathematical foundations, then learn programming, followed by learning core algorithms, and through project practice."
response2 = "I really like to eat chocolate."
response3 = "Learning machine learning can start with eat apples, then oranges, followed by jumping into a swimming pool."

responses = [response1, response2, response3]
prompts = [prompt] * 3
encoded_input = reward_tokenizer(
    prompts,
    responses,
    truncation=True,
    padding=True,
    return_tensors="pt", # 返回 PyTorch 张量
    max_length = 512
).to(device)
print(encoded_input.keys())
print(encoded_input['input_ids'].tolist()[0])

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
[1, 863, 848, 351, 361, 264, 799, 1494, 1101, 260, 2, 4735, 1494, 1101, 295, 564, 275, 10378, 10355, 261, 393, 799, 4050, 261, 1708, 293, 1101, 2233, 8785, 261, 263, 390, 663, 1105, 260, 2]


In [4]:
scores = reward_model(**encoded_input)

In [5]:
scores.logits.squeeze(-1)

tensor([ 0.0531, -5.3392, -4.6690], device='mps:0', grad_fn=<SqueezeBackward1>)

## Actor-Critic 网络

In [6]:
from transformers.modeling_outputs import CausalLMOutputWithPast

class PolicyWithValueHeadWrapper(nn.Module):
    """
    A wrapper around a Causal Language Model to add a Value Head.
    Mirrors the forward signature needed for generation and PPO loss.
    """
    def __init__(self, base_model: AutoModelForCausalLM):
        super().__init__()
        self.base_model = base_model
        # Get the hidden size from the base model's config
        config = base_model.config
        if hasattr(config, 'hidden_size'):
            hidden_size = config.hidden_size
        elif hasattr(config, 'n_embd'):
            hidden_size = config.n_embd
        else:
             raise ValueError(f"Cannot find hidden size in config: {config}")
        self.v_head = nn.Linear(hidden_size, 1)
        nn.init.normal_(self.v_head.weight, std=0.01)
        nn.init.zeros_(self.v_head.bias)
    # Define forward method matching the base model's common signature
    # This signature is crucial for compatibility with generate and PPO logic
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, # Use standard tuple hint
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # **kwargs # Allow for other potential arguments from different models
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[Tuple[torch.Tensor]]]]:
        """
        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask.
            past_key_values: KV cache from previous steps.
            use_cache: Whether to return updated KV cache.
            ... other args passed to the base model.
        Returns:
            logits: Logits from the language model head (batch_size, seq_len, vocab_size).
            values: Value predictions from the value head (batch_size, seq_len).
            past_key_values: Updated KV cache.
        """
        # Ensure output_hidden_states is True to get necessary input for v_head
        # Call base model forward
        outputs: CausalLMOutputWithPast = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = outputs.hidden_states[-1]
        values_full_seq_chunk = self.v_head(last_hidden_state).squeeze(-1)
        return outputs.logits, values_full_seq_chunk, outputs.past_key_values

In [7]:
# 加载模型和分词器
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(tokenizer.pad_token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = '[pad]'

base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
policy_model = PolicyWithValueHeadWrapper(base_model)
policy_model.to(device)
policy_model

<|endoftext|>


PolicyWithValueHeadWrapper(
  (base_model): Qwen3ForCausalLM(
    (model): Qwen3Model(
      (embed_tokens): Embedding(151936, 1024)
      (layers): ModuleList(
        (0-27): 28 x Qwen3DecoderLayer(
          (self_attn): Qwen3Attention(
            (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
            (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
            (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
          )
          (mlp): Qwen3MLP(
            (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
            (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
            (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
            (act_fn): SiLU()
          )
          (input_la

## 数据收集
从 Prompt 数据集中获取一批 Prompts，使用当前的 policy_model 为每个 Prompt 生成 Responses。

In [8]:
from transformers import DynamicCache

def apply_top_k_filter(logits: torch.Tensor, top_k: int):
    if top_k <= 0:
        return logits

    top_k_values, _ = torch.topk(logits, top_k, dim=-1)
    kth_value = top_k_values[:, -1].unsqueeze(-1)
    indices_to_remove = logits < kth_value
    logits[indices_to_remove] = -float('inf')
    return logits

def apply_top_p_filtering(logits: torch.Tensor, top_p: float):
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    sorted_probs = nn.functional.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    indices_to_remove = cumulative_probs > top_p
    indices_to_remove[..., 1:] = indices_to_remove[..., :-1].clone()
    indices_to_remove[..., 0] = False
    mask_in_original_order = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask_in_original_order.scatter_(dim=-1, index=sorted_indices, src=indices_to_remove)
    logits[mask_in_original_order] = -float('inf')
    return logits

In [9]:
from transformers import DynamicCache
from torch import nn

def generate_response_with_probs_values_batch(model: nn.Module, tokenizer, prompts, device, max_new_tokens=50, temperature=1.0, top_k=None, top_p=None, do_sample=True):
    model.eval()
    batch_size = len(prompts)

    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True, truncation=False)
    input_ids = tokenized_prompts.input_ids.to(device)
    attention_mask = tokenized_prompts.attention_mask.to(device)

    prompt_lengths = attention_mask.sum(dim=1)
    max_prompt_length = input_ids.size(1)
    max_total_length = max_prompt_length + max_new_tokens

    response_ids = torch.full((batch_size, max_new_tokens), tokenizer.pad_token_id, dtype=torch.long, device=device)
    log_probs = torch.full((batch_size, max_new_tokens), 0.0, dtype=torch.float, device=device)
    batch_values_padded = torch.full((batch_size, max_new_tokens + 1), 0.0, dtype=torch.float, device=device)

    is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    sequence_lengths = prompt_lengths.clone()
    current_response_length = 0 # Number of response tokens generated so far (0 to max_new_tokens-1)
    past_key_values = None

    with torch.no_grad():
        # First call: Process the full prompt
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=None,
            use_cache=True
        )
        # logits (bs, prompt_len, vs), values_chunk (bs, prompt_len), past_kv (DynamicCache)
        logits, values_chunk_prompt, past_key_values = outputs

    batch_values_padded[:, 0] = values_chunk_prompt[:, -1] # Shape (batch_size,)
    last_token_logits = logits[torch.arange(batch_size, device=device), prompt_lengths - 1, :] # (1, vocab_size)

    for step in range(max_new_tokens):
        if torch.all(is_finished):
            break

        current_logits = last_token_logits

        if temperature != 1.0:
            if temperature < 0:
                temperature = 1.0
            current_logits = current_logits / temperature

        log_probs_dist = nn.functional.log_softmax(current_logits, dim=-1)  # (batch_size, vocab_size)

        if do_sample:
            filtered_logits = current_logits.clone()
            if top_k is not None and top_k > 0:
                 filtered_logits = apply_top_k_filter(filtered_logits, top_k)
            if top_p is not None and top_p < 1.0:
                 filtered_logits = apply_top_p_filtering(filtered_logits, top_p)
            probs = torch.softmax(filtered_logits, dim=-1)  # (batch_size, vocab_size)
            next_tokens = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
        else:
            next_tokens = torch.argmax(current_logits, dim=-1).unsqueeze(-1)

        sampled_token_log_probs = log_probs_dist.gather(1, next_tokens).squeeze(1) # (batch_size,)

        active_sequences = ~is_finished

        response_ids[active_sequences, current_response_length] = next_tokens[active_sequences].squeeze(1)
        log_probs[active_sequences, current_response_length] = sampled_token_log_probs[active_sequences]

        sequence_lengths[active_sequences] += 1

        # 生成是否结束的mask
        current_tokens_are_eos = (next_tokens.squeeze(1) == tokenizer.eos_token_id)
        reached_max_len = (sequence_lengths >= max_total_length)
        is_finished = is_finished | current_tokens_are_eos | reached_max_len

        # 生成下一个seq_len前向传播的数据
        current_input_ids_for_next_step = next_tokens
        current_total_length_for_mask = max_prompt_length + step + 1
        new_attention_mask = torch.zeros(batch_size, current_total_length_for_mask, dtype=torch.long, device=device)
        for i in range(batch_size):
             new_attention_mask[i, :sequence_lengths[i]] = 1

        with torch.no_grad():
            # 调用 policy_model 并获取 logits 和 value
            outputs = model(
                input_ids=current_input_ids_for_next_step,
                attention_mask=new_attention_mask,
                past_key_values=DynamicCache.from_legacy_cache(past_key_values),
                use_cache=True
            )
            logits, values_chunk_step, past_key_values = outputs

        batch_values_padded[:, step + 1] = values_chunk_step.squeeze(1)
        last_token_logits = logits.squeeze(1)

        current_response_length += 1

    # 获取生成数据
    collected_data = {
        'prompts': [],
        'responses': [],
        'full_sequence_ids': [],
        'response_ids': [],
        'log_probs': [],
        'values': [],
        'actual_response_length': []
    }
    for i in range(batch_size):
        prompt_ids = tokenizer(prompts[i], return_tensors="pt").input_ids.to(device)

        seq_generated_response_ids = response_ids[i]
        actual_response_length = (seq_generated_response_ids != tokenizer.pad_token_id).sum().item() # Count non-pad tokens

        actual_response_ids = seq_generated_response_ids[:actual_response_length]
        full_sequence_ids = torch.cat([prompt_ids.squeeze(0), actual_response_ids], dim=0)
        response_text = tokenizer.decode(actual_response_ids, skip_special_tokens=True)

        collected_data['prompts'].append(prompts[i])
        collected_data['responses'].append(response_text)
        collected_data['full_sequence_ids'].append(full_sequence_ids)
        collected_data['response_ids'].append(actual_response_ids)
        collected_data['log_probs'].append(log_probs[i][:actual_response_length])
        collected_data['values'].append(batch_values_padded[i, :actual_response_length + 1])
        collected_data['actual_response_length'].append(actual_response_length)
    return collected_data


In [10]:
test_prompts = [
        "写一个关于机器学习的短介绍。",
        "如何用Python实现一个简单的线性回归模型？",
    ]

rollout_data_batch = generate_response_with_probs_values_batch(
    policy_model, tokenizer, test_prompts, device, max_new_tokens=200, temperature=1.0, top_k=5, top_p=0.8
)

## 收集奖励

In [11]:
ref_base_model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = PolicyWithValueHeadWrapper(ref_base_model).to(device)

In [12]:
def pad_list_of_tensors(list_of_tensors, pad_value, dtype=None, device=None):
    if not list_of_tensors:
        return torch.empty(0, dtype=dtype, device=device), torch.empty(0, dtype=torch.bool, device=device)

    max_len = max(t.size(0) for t in list_of_tensors)
    padded_tensors = []
    masks = []

    if dtype is None:
        dtype = list_of_tensors[0].dtype
    if device is None:
        device = list_of_tensors[0].device

    for t in list_of_tensors:
        pad_len = max_len - t.size(0)
        padded_t = nn.functional.pad(t, (0, pad_len), value=pad_value)
        padded_tensors.append(padded_t)

        mask = torch.ones(t.size(0), dtype=torch.bool, device=device)
        mask = nn.functional.pad(mask, (0, pad_len), value=False)
        masks.append(mask)

    return torch.stack(padded_tensors), torch.stack(masks)

def get_reward_scores(data, ref_model, reward_model, tokenizer, reward_tokenizer, kl_beta=0.01, device=None):
    batch_size = len(data['prompts'])

    prompts = data['prompts']
    response_texts = data['responses']
    list_full_sequence_ids = data['full_sequence_ids']
    list_response_ids = data['response_ids']
    list_policy_log_probs = data['log_probs']
    list_values = data['values']
    actual_response_lengths = data['actual_response_length']

    # 1.对收集数据进行pad
    # pad full_sequence_id
    padded_full_sequence_ids, attention_mask_full_sequence = pad_list_of_tensors(
        list_full_sequence_ids,
        pad_value=tokenizer.pad_token_id,
        dtype=torch.long,
        device=device
    )

    # pad response_id
    padded_response_ids, response_mask = pad_list_of_tensors(
        list_response_ids,
        pad_value=tokenizer.pad_token_id,
        dtype=torch.long,
        device=device
    )
    max_response_len = padded_response_ids.size(1)

    # pad policy_log_probs
    padded_policy_log_probs, _ = pad_list_of_tensors(
        list_policy_log_probs,
        pad_value=0.0,
        dtype=torch.float,
        device=device
    )

    # pad values
    # padded_values.size(1) = max_response_len + 1 : 对prompt的估计作为第1个元素
    padded_values, _ = pad_list_of_tensors(
        list_values,
        pad_value=0.0,
        dtype=torch.float,
        device=device
    )

    actual_prompt_lengths = torch.tensor([len(list_full_sequence_ids[i]) - actual_response_lengths[i] for i in range(batch_size)], device=device)

    # 2.计算奖励分数
    reward_model_inputs = [p + r for p, r in zip(prompts, response_texts)]
    tokenized_reward_inputs = reward_tokenizer(
        reward_model_inputs,
        return_tensors='pt',
        padding=True,
        truncation=True
    ).to(device)

    with torch.no_grad():
        reward_outputs = reward_model(**tokenized_reward_inputs)
        reward_scores = reward_outputs.logits.squeeze(-1)

    # 3.计算KL散度
    with torch.no_grad():
        ref_outputs = ref_model(
            input_ids=padded_full_sequence_ids,
            attention_mask=attention_mask_full_sequence,
            return_dict=True
        )
        ref_logits = ref_outputs[0]

    response_step_indices = torch.arange(max_response_len, device=device).unsqueeze(0).expand(batch_size, -1) # Shape: (batch_size, max_response_len)

    expanded_prompt_lengths = actual_prompt_lengths.unsqueeze(-1).expand(-1, max_response_len)
    context_indices_in_full_logits = expanded_prompt_lengths + response_step_indices - 1
    context_indices_expanded = context_indices_in_full_logits.unsqueeze(-1).expand(-1, -1, ref_logits.size(-1))

    # 保留response部分的logits
    ref_logits_for_response_context = torch.gather(ref_logits, dim=1, index=context_indices_expanded)
    ref_log_probs_dist = nn.functional.log_softmax(ref_logits_for_response_context, dim=-1)

    # 得到response_ids部分对应的的logits
    ref_log_probs_generated = ref_log_probs_dist.gather(2, padded_response_ids.unsqueeze(-1)).squeeze(-1)

    kl_per_token = padded_policy_log_probs - ref_log_probs_generated
    masked_kl_per_token = kl_per_token * response_mask.float()
    total_kl_divergence = masked_kl_per_token.sum(dim=-1)

    # 4.最终总奖励
    rewards = reward_scores - kl_beta * total_kl_divergence

    return {
        "rewards": rewards,                    # Shape: (batch_size,)
        "padded_values": padded_values,              # Shape: (batch_size, max_response_len + 1)
        "padded_policy_log_probs": padded_policy_log_probs,    # Shape: (batch_size, max_response_len)
        "padded_response_ids": padded_response_ids,        # Shape: (batch_size, max_response_len)
        "padded_full_sequence_ids": padded_full_sequence_ids,   # Shape: (batch_size, max_total_len)
        "attention_mask_full_sequence": attention_mask_full_sequence, # Shape: (batch_size, max_total_len)
        "response_mask": response_mask               # Shape: (batch_size, max_response_len)
}

In [13]:
res = get_reward_scores(rollout_data_batch, ref_model, reward_model, tokenizer, reward_tokenizer, device=device)
print(res['rewards'])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


tensor([-7.7583, -6.3136], device='mps:0')


## 计算GAE

In [14]:
def calculate_gae(rewards, padded_values, response_mask, device, gamma=0.99, lambda_=0.95):
    batch_size, max_response_len = response_mask.shape

    rewards = rewards.to(device)
    padded_values = padded_values.to(device)
    response_mask = response_mask.to(device)

    rewards_padded_per_step = torch.zeros(batch_size, max_response_len, dtype=torch.float, device=device)

    actual_response_lengths = response_mask.sum(dim=1)
    valid_indices = actual_response_lengths > 0
    if torch.any(valid_indices):
        last_valid_response_indices = actual_response_lengths[valid_indices] - 1
        batch_indices = torch.arange(batch_size, device=device)[valid_indices]
        rewards_padded_per_step[batch_indices, last_valid_response_indices] = rewards[valid_indices]

    # 优势计算
    delta = rewards_padded_per_step + gamma * padded_values[:, 1:] - padded_values[:, :-1]
    advantages = torch.zeros_like(delta, device=device)
    last_gae_lambda = torch.zeros(batch_size, device=device)

    for t in reversed(range(max_response_len)):
        delta_t = delta[:, t]
        masked_delta_t = delta_t * response_mask[:, t].float()
        advantages_t = masked_delta_t + gamma * lambda_ * last_gae_lambda
        advantages[:, t] = advantages_t
        last_gae_lambda = advantages_t

    advantages = advantages * response_mask.float()
    returns = advantages + padded_values[:, :-1]
    return advantages, returns

In [15]:
advantages, returns = calculate_gae(res['rewards'], res['padded_values'], res['response_mask'], device)
print(advantages)
print(returns)

tensor([[-5.7968e-01, -9.6657e-02,  6.4990e-01,  7.2113e-01, -3.9091e-01,
          3.4005e-02, -4.1875e-01, -1.3563e-01, -3.6192e-01,  3.7891e-02,
         -2.3264e-01,  3.5953e-01,  3.5520e-01,  1.0672e-02,  2.7089e-01,
          1.3855e-01,  1.4101e-01, -3.0462e-01,  1.4752e-01,  8.3769e-02,
          1.0854e-01, -1.2197e-01, -3.2931e-01, -5.7190e-01,  1.8100e-01,
         -2.0163e-01,  4.6049e-02, -3.2317e-01, -8.7065e-02,  1.2481e-01,
          1.8563e-01, -1.7840e-01,  6.1262e-01, -3.3524e-01,  3.9890e-01,
          1.5589e-01, -6.8010e-02, -1.8405e-01,  2.1685e-01,  3.4398e-02,
          1.5487e-01, -5.3448e-01, -6.0253e-02,  1.0317e-01,  2.4802e-01,
         -3.1722e-01,  2.7089e-01, -4.1476e-01, -2.0222e-01, -1.8442e-01,
         -4.2875e-03, -1.5608e-01, -4.1056e-01, -1.2294e-02, -2.5024e-02,
          1.8066e-01, -2.7647e-01, -1.0028e-01, -6.6354e-01, -3.4547e-01,
          8.4479e-02, -2.2000e-01,  1.2214e-01, -1.4676e-01,  3.2425e-01,
         -7.1166e-02,  1.5201e-01,  3.

## 计算PPO损失

In [16]:
def calculate_ppo_loss(
        policy_model,
        padded_policy_log_probs,
        padded_response_ids,
        padded_full_sequence_ids,
        attention_mask_full_sequence,
        response_mask,
        advantages,
        returns,
        clip_param,
        vf_coef,
        ent_coef,
        device
):
    policy_model.train()

    batch_size, max_response_len = response_mask.shape
    max_total_len = padded_full_sequence_ids.size(1)

    max_prompt_length = padded_full_sequence_ids.size(1) - padded_response_ids.size(1)
    policy_model.to(device)

    # 1.计算模型更新后的logits和values
    new_logits_full, new_values_full, _ = policy_model(
        input_ids=padded_full_sequence_ids.to(device),
        attention_mask=attention_mask_full_sequence.to(device)
    )
    new_logits_for_response = new_logits_full[:, max_prompt_length - 1 : max_prompt_length - 1 + max_response_len, :]
    new_log_softmax = nn.functional.log_softmax(new_logits_for_response, dim=-1)
    new_policy_log_probs = new_log_softmax.gather(2, padded_response_ids.to(device).unsqueeze(-1)).squeeze(-1)

    # 2,PPO策略损失
    ratio = torch.exp(new_policy_log_probs - padded_policy_log_probs)
    advantages = advantages.to(device)
    advantages_flat = advantages[response_mask.to(device)]
    if advantages_flat.numel() > 0:
        advantages_norm = (advantages - advantages_flat.mean()) / (advantages_flat.std() + 1e-8)
    else:
         advantages_norm = torch.zeros_like(advantages, device=device)

    advantages_norm = advantages_norm * response_mask.float().to(device)

    surr1 = ratio * advantages_norm
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages_norm
    policy_loss_per_token = -torch.min(surr1, surr2)

    # 根据长度归一化
    num_non_padded_tokens = response_mask.sum()
    if num_non_padded_tokens > 0:
        policy_loss = (policy_loss_per_token * response_mask.float().to(device)).sum() / num_non_padded_tokens
    else:
        policy_loss = torch.tensor(0.0, device=device)

    # 3.Value loss
    value_predictions = new_values_full[:, max_prompt_length - 1: max_prompt_length - 1 + max_response_len]
    value_targets= returns.to(device)
    value_error = value_predictions - value_targets
    value_loss_per_token = value_error.pow(2)

    if num_non_padded_tokens > 0:
        value_loss = (value_loss_per_token * response_mask.float().to(device)).sum() / num_non_padded_tokens
    else:
        value_loss = torch.tensor(0.0, device=device)

    # 4.交叉熵奖励
    probs = torch.exp(new_log_softmax)
    entropy_per_token = -(probs * torch.log(probs + 1e-8)).sum(dim=-1)
    if num_non_padded_tokens > 0:
        entropy = (entropy_per_token * response_mask.float().to(device)).sum() / num_non_padded_tokens
    else:
        entropy = torch.tensor(0.0, device=device)

    total_loss = policy_loss + vf_coef * value_loss - ent_coef * entropy
    return total_loss, policy_loss, value_loss, entropy

In [17]:
new_res = res.copy()
new_res.pop('rewards')
new_res.pop('padded_values')
calculate_ppo_loss(
    policy_model, advantages=advantages, returns=returns, clip_param=0.2, vf_coef=0.5,
    ent_coef=0.05, device=device, **new_res
)


(tensor(1.2228, device='mps:0', grad_fn=<SubBackward0>),
 tensor(0.0340, device='mps:0', grad_fn=<DivBackward0>),
 tensor(2.4647, device='mps:0', grad_fn=<DivBackward0>),
 tensor(0.8727, device='mps:0', grad_fn=<DivBackward0>))

## 训练数据集

In [18]:
from torch.utils.data import Dataset, DataLoader


class PPODataset(Dataset):
    def __init__(self,
                 padded_policy_log_probs: torch.Tensor,
                 padded_response_ids: torch.Tensor,
                 padded_full_sequence_ids: torch.Tensor,
                 attention_mask_full_sequence: torch.Tensor,
                 response_mask: torch.Tensor,
                 advantages: torch.Tensor,
                 returns: torch.Tensor):
        self.n_samples = padded_policy_log_probs.size(0)
        self.padded_policy_log_probs = padded_policy_log_probs
        self.padded_response_ids = padded_response_ids
        self.padded_full_sequence_ids = padded_full_sequence_ids
        self.attention_mask_full_sequence = attention_mask_full_sequence
        self.response_mask = response_mask
        self.advantages = advantages
        self.returns = returns

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return (
            self.padded_policy_log_probs[idx],
            self.padded_response_ids[idx],
            self.padded_full_sequence_ids[idx],
            self.attention_mask_full_sequence[idx],
            self.response_mask[idx],
            self.advantages[idx],
            self.returns[idx],
        )

In [19]:
import random

def get_prompt_batch(all_prompts, batch_size):
    if not all_prompts:
        return []
    return random.sample(all_prompts, min(batch_size, len(all_prompts)))

## 训练循环

In [20]:
import time


def train_ppo(
        policy_model: PolicyWithValueHeadWrapper,
        ref_model: AutoModelForCausalLM,
        reward_model: AutoModelForSequenceClassification,
        tokenizer: AutoTokenizer,
        reward_tokenizer: AutoTokenizer,
        all_prompts_dataset: list[str],
        device: torch.device,
        epochs: int = 10,
        rollout_batch_size: int = 16,
        ppo_epochs_per_rollout: int = 4,
        minibatch_size: int = 4,
        learning_rate: float = 5e-6,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        clip_param: float = 0.2,
        vf_coef: float = 0.5,
        ent_coef: float = 0.005,
        kl_beta: float = 0.1,
        max_new_tokens_rollout: int = 50,
        temperature: float = 0.7,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = 0.9,
        do_sample: bool = True,
):
    optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch + 1}/{epochs} ---")

        # 收集数据阶段
        policy_model.eval()
        current_prompts = get_prompt_batch(all_prompts_dataset, rollout_batch_size)
        print(current_prompts)
        all_rollout_data = generate_response_with_probs_values_batch(
            policy_model,
            tokenizer,
            current_prompts,
            device=device,
            max_new_tokens=max_new_tokens_rollout,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        rollout_end_time = time.time()
        print(f"  Rollout 完成. 收集到 {len(all_rollout_data['responses'])} 条有效数据 (耗时: {rollout_end_time - epoch_start_time:.2f}s)。")

        print(f"  开始计算 Rewards 和 GAE...")
        reward_gae_start_time = time.time()
        reward_dict = get_reward_scores(
            all_rollout_data, # Pass raw data
            ref_model,
            reward_model,
            tokenizer,
            reward_tokenizer,
            kl_beta=kl_beta,
            device=device,
        )
        (
            rewards_batch, # Shape: (batch_size,)
            padded_values_batch, # Shape: (batch_size, max_response_len + 1)
            padded_policy_log_probs_batch, # Shape: (batch_size, max_response_len)
            padded_response_ids_batch, # Shape: (batch_size, max_response_len)
            padded_full_sequence_ids_batch, # Shape: (batch_size, max_total_len)
            attention_mask_full_sequence_batch, # Shape: (batch_size, max_total_len)
            response_mask_batch, # Shape: (batch_size, max_response_len)
        ) = (
            reward_dict['rewards'],
            reward_dict['padded_values'],
            reward_dict['padded_policy_log_probs'],
            reward_dict['padded_response_ids'],
            reward_dict['padded_full_sequence_ids'],
            reward_dict['attention_mask_full_sequence'],
            reward_dict['response_mask'],
        )

        reward_gae_step1_end_time = time.time()
        print(f"  Rewards 数据准备完成 (耗时: {reward_gae_step1_end_time - reward_gae_start_time:.2f}s)。")

        # 计算优势和回报
        advantages_batch, returns_batch = calculate_gae(
            rewards_batch, # Shape: (batch_size,)
            padded_values_batch, # Shape: (batch_size, max_response_len + 1)
            response_mask_batch, # Shape: (batch_size, max_response_len)
            gamma=gamma,
            lambda_=lambda_,
            device=device
        )
        reward_gae_end_time = time.time()
        print(f"  GAE 计算完成 (耗时: {reward_gae_end_time - reward_gae_step1_end_time:.2f}s, 总计: {reward_gae_end_time - reward_gae_start_time:.2f}s)。")
        avg_rollout_reward = rewards_batch.mean().item() if rewards_batch.numel() > 0 else 0.0
        print(f"  平均 Rollout Reward: {avg_rollout_reward:.4f}")

        # PPO优化
        optimization_start_time = time.time()

        valid_indices = (response_mask_batch.sum(dim=1) > 0).cpu()
        num_valid_sequences = valid_indices.sum().item()

        ppo_dataset = PPODataset(
            padded_policy_log_probs=padded_policy_log_probs_batch[valid_indices].cpu(),
            padded_response_ids=padded_response_ids_batch[valid_indices].cpu(),
            padded_full_sequence_ids=padded_full_sequence_ids_batch[valid_indices].cpu(),
            attention_mask_full_sequence=attention_mask_full_sequence_batch[valid_indices].cpu(),
            response_mask=response_mask_batch[valid_indices].cpu(),
            advantages=advantages_batch[valid_indices].cpu(),
            returns=returns_batch[valid_indices].cpu(),
        )
        ppo_dataloader = DataLoader(ppo_dataset, batch_size=minibatch_size, shuffle=True)
        total_minibatches = len(ppo_dataloader)
        if total_minibatches == 0:
            continue

        for ppo_epoch in range(ppo_epochs_per_rollout):
            policy_model.train() #

            total_loss_epoch = 0
            policy_loss_epoch = 0
            value_loss_epoch = 0
            entropy_epoch = 0
            minibatch_count = 0

            ppo_epoch_start_time = time.time()

            for i, minibatch in enumerate(ppo_dataloader):
                (
                    mb_padded_policy_log_probs, # Shape: (mb_size, max_response_len)
                    mb_padded_response_ids, # Shape: (mb_size, max_response_len)
                    mb_padded_full_sequence_ids, # Shape: (mb_size, max_total_len in this batch)
                    mb_attention_mask_full_sequence, # Shape: (mb_size, max_total_len in this batch)
                    mb_response_mask, # Shape: (mb_size, max_response_len)
                    mb_advantages, # Shape: (mb_size, max_response_len)
                    mb_returns, # Shape: (mb_size, max_response_len)
                ) = [tensor.to(device) for tensor in minibatch]

                minibatch_count += 1

                mb_total_loss, mb_policy_loss, mb_value_loss, mb_entropy = calculate_ppo_loss(
                    policy_model,
                    mb_padded_policy_log_probs,
                    mb_padded_response_ids,
                    mb_padded_full_sequence_ids,
                    mb_attention_mask_full_sequence,
                    mb_response_mask,
                    mb_advantages,
                    mb_returns,
                    clip_param=clip_param,
                    vf_coef=vf_coef,
                    ent_coef=ent_coef,
                    device=device,
                )

                optimizer.zero_grad()
                mb_total_loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss_epoch += mb_total_loss.item()
                policy_loss_epoch += mb_policy_loss.item()
                value_loss_epoch += mb_value_loss.item()
                entropy_epoch += mb_entropy.item()

            ppo_epoch_end_time = time.time()

            if minibatch_count > 0:
                avg_total_loss = total_loss_epoch / minibatch_count
                avg_policy_loss = policy_loss_epoch / minibatch_count
                avg_value_loss = value_loss_epoch / minibatch_count
                avg_entropy = entropy_epoch / minibatch_count
                print(f"  PPO Epoch {ppo_epoch+1}/{ppo_epochs_per_rollout} 结束: Avg Total Loss={avg_total_loss:.4f}, Avg Policy Loss={avg_policy_loss:.4f}, Avg Value Loss={avg_value_loss:.4f}, Avg Entropy={avg_entropy:.4f} (耗时: {ppo_epoch_end_time - ppo_epoch_start_time:.2f}s)")
            else:
                print(f"  PPO Epoch {ppo_epoch+1}/{ppo_epochs_per_rollout} 结束: 没有有效 Minibatch.")

In [21]:
all_prompts_for_training = [
        "写一个关于机器学习的短介绍。",
        "如何用Python实现一个简单的线性回归模型？",
        "解释一下Transformer模型的自注意力机制。",
        "PPO算法的核心思想是什么？",
        "请举一个强化学习在实际中的应用例子。",
        "神经网络中的激活函数有哪些？它们的作用是什么？",
        "什么是过拟合和欠拟合？如何解决？",
        "介绍一下卷积神经网络（CNN）的基本结构。",
        "如何评估一个分类模型的性能？",
        "写一段关于深度学习的未来展望。",
        "描述一下梯度下降算法的工作原理。",
        "请写一个关于数据清洗的步骤。",
        "什么是迁移学习？它有什么优势？",
        "生成一个简单的聊天对话。",
        "推荐一本学习PyTorch的书籍。",
        "解释一下GAN（生成对抗网络）的基本概念。",
        "写一个关于自然语言处理（NLP）的应用例子。",
        "如何在机器学习项目中选择合适的算法？",
        "介绍一下决策树的工作原理。",
        "写一段鼓励机器学习初学者的文字。",
         "请提供一个关于自然语言处理的进阶概念解释。",
         "如何使用BERT进行文本分类？",
         "描述一下Seq2Seq模型及其在机器翻译中的应用。",
         "什么是注意力机制（Attention Mechanism）？它解决了什么问题？",
         "解释一下LSTM和GRU这两种循环神经网络。",
         "推荐几个常用的自然语言处理工具库。",
         "如何进行文本数据的预处理？",
         "什么是词嵌入（Word Embedding）？",
         "介绍一下预训练语言模型的常见架构（如BERT, GPT, RoBERTa）。",
         "写一个关于情感分析的应用案例。",
] * 5

trained_policy_model_wrapper = train_ppo(
    policy_model=policy_model,
    ref_model=ref_model,
    reward_model=reward_model,
    tokenizer=tokenizer,
    reward_tokenizer=reward_tokenizer,
    all_prompts_dataset=all_prompts_for_training,
    device=device,
    epochs=3,
    rollout_batch_size=2,
    ppo_epochs_per_rollout=2,
    minibatch_size=1,
    learning_rate=5e-6,
    gamma=0.99,
    lambda_=0.95,
    clip_param=0.2,
    vf_coef=0.5,
    ent_coef=0.005,
    kl_beta=0.1,
    max_new_tokens_rollout=50, # Reduce max tokens for example
    temperature=1.0,
    top_k=50,
    top_p=0.9,
    do_sample=True,
)


--- Epoch 1/3 ---
['什么是过拟合和欠拟合？如何解决？', '写一段关于深度学习的未来展望。']
  Rollout 完成. 收集到 2 条有效数据 (耗时: 7.24s)。
  开始计算 Rewards 和 GAE...
  Rewards 数据准备完成 (耗时: 1.15s)。
  GAE 计算完成 (耗时: 1.07s, 总计: 2.22s)。
  平均 Rollout Reward: -13.7346
  PPO Epoch 1/2 结束: Avg Total Loss=22.3315, Avg Policy Loss=0.3860, Avg Value Loss=43.9277, Avg Entropy=3.6694 (耗时: 84.94s)
  PPO Epoch 2/2 结束: Avg Total Loss=13.6000, Avg Policy Loss=0.1256, Avg Value Loss=26.9861, Avg Entropy=3.7391 (耗时: 99.48s)

--- Epoch 2/3 ---
['如何评估一个分类模型的性能？', '如何使用BERT进行文本分类？']
  Rollout 完成. 收集到 2 条有效数据 (耗时: 14.97s)。
  开始计算 Rewards 和 GAE...
  Rewards 数据准备完成 (耗时: 10.11s)。
  GAE 计算完成 (耗时: 0.21s, 总计: 10.33s)。
  平均 Rollout Reward: -7.5937
  PPO Epoch 1/2 结束: Avg Total Loss=9.8873, Avg Policy Loss=0.1025, Avg Value Loss=19.5885, Avg Entropy=1.8939 (耗时: 104.07s)
  PPO Epoch 2/2 结束: Avg Total Loss=4.7049, Avg Policy Loss=0.0488, Avg Value Loss=9.3304, Avg Entropy=1.8279 (耗时: 99.00s)

--- Epoch 3/3 ---
['介绍一下决策树的工作原理。', '介绍一下卷积神经网络（CNN）的基本结构。']
  Rollout 

RuntimeError: MPS backend out of memory (MPS allocated: 16.08 GB, other allocations: 1.75 GB, max allowed: 18.13 GB). Tried to allocate 593.50 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).