# Beam Search

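Beam search is a heuristic search algorithm that keeps the `beam_size` most probable candidate sequences at each step (called `beam_width` in the code below), rather than only the single best one as greedy search does. It trades off compute cost against output quality.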
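To make the contrast with greedy search concrete, here is a minimal sketch on a hand-made toy model. The table `TOY_PROBS` and the helper `beam_search` are hypothetical and exist only for illustration; they are not part of the GPT-2 code in this notebook. Greedy search is treated as the special case `beam_size=1`.

```python
import math

# Hypothetical toy model: the next-token distribution depends only on the last token.
TOY_PROBS = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"x": 0.9, "y": 0.1},
}

def beam_search(beam_size, steps=2):
    """Return the best (tokens, cumulative log-probability) pair after `steps` steps."""
    beams = [(["<s>"], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            # Extend every surviving path with every possible next token.
            for tok, p in TOY_PROBS[tokens[-1]].items():
                candidates.append((tokens + [tok], score + math.log(p)))
        # Keep only the beam_size highest-scoring paths.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams[0]

greedy_tokens, greedy_score = beam_search(beam_size=1)
beam_tokens, beam_score = beam_search(beam_size=2)
print(greedy_tokens, round(math.exp(greedy_score), 2))
print(beam_tokens, round(math.exp(beam_score), 2))
```

With `beam_size=1` the search commits to `a` at the first step (probability 0.6) and ends with a path of probability 0.3. With `beam_size=2` the path through `b` survives the first step and wins with probability 0.36.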

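Key characteristics
- Path expansion: at each step, every candidate sequence is extended with several possible next tokens
- Path selection: only the beam_size highest-scoring paths are kept
- Probability handling: log-probabilities are summed to avoid numerical underflow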
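The underflow point can be checked with a few lines of plain Python. This is only an illustrative sketch: the 200 per-token probabilities of 0.01 are made-up values, chosen so that their product falls below the smallest positive double.

```python
import math

# Hypothetical per-token probabilities for a 200-token sequence.
probs = [0.01] * 200

# Multiplying raw probabilities: the true value 1e-400 is not representable
# as a 64-bit float, so the product collapses to 0.0.
product = 1.0
for p in probs:
    product *= p

# Summing log-probabilities keeps the score finite and comparable.
log_sum = sum(math.log(p) for p in probs)

print(product)
print(log_sum)
```

Once two candidate paths both evaluate to `0.0` they can no longer be ranked, whereas their log-probability sums remain distinct.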

In [2]:
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

text = "Once upon a time"
input_ids = tokenizer.encode(text, return_tensors="pt")
print(input_ids)

tensor([[7454, 2402,  257,  640]])


In [5]:
max_length = 8  # maximum number of new tokens to generate
beam_width = 3

# Initialize the beam
beams = [(input_ids, 0.0)]  # (token sequence, cumulative log-probability)

completed_beams = []
# Beam search main loop
for i in range(max_length):
    new_beams = []

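    # Expand each candidate sequence
    for beam_input_ids, beam_score in beams:
        # Get the logits for the next token (no gradients are needed for inference)
        with torch.no_grad():
            outputs = model(beam_input_ids)
        # outputs.logits shape: [batch_size, seq_len, vocab_size]
        next_token_logits = outputs.logits[:, -1, :]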

        # Force the EOS token's logit up (simulates early termination)
        if i > 4:
            next_token_logits[:, tokenizer.eos_token_id] = 10

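        # Compute log-probabilities
        # Beam search looks for the path with the highest joint probability;
        # taking logs turns the product into a sum
        next_token_scores = F.log_softmax(next_token_logits, dim=-1)

        # Take the top-k candidates
        top_k_scores, top_k_tokens = torch.topk(next_token_scores, beam_width, dim=-1)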

        # Extend new paths (index with k so the outer step counter i is not overwritten)
        for k in range(beam_width):
            next_token = top_k_tokens[0, k].unsqueeze(0).unsqueeze(0)
            next_score = top_k_scores[0, k].item()
            new_input_ids = torch.cat([beam_input_ids, next_token], dim=-1)
            # Log-probabilities are simply added
            new_score = beam_score + next_score
            new_beams.append((new_input_ids, new_score))

    # Handle EOS tokens
    remaining_beams = []
    for beam_input_ids, beam_score in new_beams:
        if beam_input_ids[0, -1].item() == tokenizer.eos_token_id:
            completed_beams.append((beam_input_ids, beam_score))
        else:
            remaining_beams.append((beam_input_ids, beam_score))

    # Keep the beam_width highest-scoring paths
    beams = sorted(remaining_beams, key=lambda x: x[1], reverse=True)[:beam_width]

    # Termination check (>= because several beams can finish in the same step)
    if len(completed_beams) >= beam_width:
        break
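The loop above collects finished paths but stops before choosing a final output. One possible way to pick a winner is sketched below, under assumptions: the helper `pick_best` and its `alpha` parameter are hypothetical, and length normalization is a common add-on rather than something this notebook uses. The sketch works on plain Python lists with illustrative token IDs and scores so that it runs without downloading the model.

```python
def pick_best(finished, alpha=0.0):
    """Return the (tokens, score) pair with the highest length-normalized score.

    alpha=0.0 compares raw cumulative log-probabilities;
    alpha=1.0 compares the average log-probability per token.
    """
    return max(finished, key=lambda beam: beam[1] / (len(beam[0]) ** alpha))

# Illustrative finished beams: (token IDs, cumulative log-probability).
finished = [
    ([7454, 2402, 257, 640, 11], -4.0),
    ([7454, 2402, 257, 640, 11, 612, 373, 257], -5.0),
]

print(pick_best(finished, alpha=0.0))  # raw score favors the shorter path
print(pick_best(finished, alpha=1.0))  # per-token average favors the longer path
```

Raw cumulative log-probabilities tend to favor short sequences, because every extra token adds a negative term; dividing by a power of the length is one way to offset that. To apply the same idea to `completed_beams` from the loop above, the length would be `ids.shape[-1]` instead of `len(...)`, and the winner could be turned back into text with `tokenizer.decode(ids[0])`.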