# Beam Search in PyTorch

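Beam search is the most basic search-based decoding algorithm. Here we focus on its implementation details.

Suppose the beam size is beam_size=2. At step t=1, decoding produces beam_size=2 paths.

At step t=2, each of the 2 paths has beam_size candidate tokens, which gives 2x2 candidate paths. We then keep the beam_size paths with the highest cumulative score, where the score of a path is the sum of the log-probabilities of its tokens.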

![](./image/beam-search.png)
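
The expand-then-prune step described above can be sketched without a real model. This is a minimal illustration, not the notebook's implementation: `TOY_NEXT` and `beam_step` are hypothetical names, and the probabilities are made up so that the next-token distribution depends only on the last token.

```python
import math

# Hypothetical toy model: the next-token distribution depends only on the last token.
TOY_NEXT = {
    None: {"a": 0.6, "b": 0.3, "c": 0.1},
    "a": {"a": 0.1, "b": 0.7, "c": 0.2},
    "b": {"a": 0.5, "b": 0.1, "c": 0.4},
    "c": {"a": 0.3, "b": 0.3, "c": 0.4},
}

def beam_step(beams, beam_size):
    """Expand every path with its top beam_size tokens, then keep the best beam_size paths."""
    candidates = []
    for tokens, score in beams:
        last = tokens[-1] if tokens else None
        probs = TOY_NEXT[last]
        top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:beam_size]
        for token, prob in top:
            # Path score = sum of token log-probabilities along the path
            candidates.append((tokens + [token], score + math.log(prob)))
    kept = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return candidates, kept

beam_size = 2
beams = [([], 0.0)]
candidates_t1, beams = beam_step(beams, beam_size)  # 1 path  x 2 tokens = 2 candidates
candidates_t2, beams = beam_step(beams, beam_size)  # 2 paths x 2 tokens = 4 candidates
print(len(candidates_t1), len(candidates_t2))
print([tokens for tokens, _ in beams])
```

At t=2 the four candidates are pruned back to two, so the number of live paths never exceeds beam_size between steps.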

In [1]:
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

text = "Once upon a time"
input_ids = tokenizer.encode(text, return_tensors="pt")
print(input_ids)

In [2]:
max_length = 8
beam_width = 3

# Initialize the beam with the prompt and a cumulative log-probability of 0
beams = [(input_ids, 0.0)]
print(beams)
completed_beams = []

for i in range(max_length):
    new_beams = []

    for beam_input_ids, beam_score in beams:
        # Do not pick a single next token directly.
        # Instead, take the top beam_width candidate next tokens from the logits.
        with torch.no_grad():
            outputs = model(beam_input_ids)
            next_token_logits = outputs.logits[:, -1, :]

        if i > 4:
            # Push the EOS logit up to simulate paths that terminate early
            next_token_logits[:, tokenizer.eos_token_id] = 10

        # Why sum log-probabilities?
        # Beam search looks for the path with the highest joint probability;
        # taking the log turns the product into a sum:
        # max(p(x3|x1,x2) * p(x2|x1)) -> max(log p(x3|x1,x2) + log p(x2|x1))
        next_token_scores = F.log_softmax(next_token_logits, dim=-1)
        top_k_scores, top_k_tokens = torch.topk(next_token_scores, 
                                                beam_width, dim=-1)

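        # Accumulate the path score for each candidate token.
        # Use k here: reusing i would overwrite the outer step counter checked above.
        for k in range(beam_width):
            next_token = top_k_tokens[0, k].unsqueeze(0).unsqueeze(0)
            next_score = top_k_scores[0, k].item()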
            new_input_ids = torch.cat([beam_input_ids, next_token], dim=-1)
            new_score = beam_score + next_score
            new_beams.append((new_input_ids, new_score))

    # Note that there are up to beam_width^2 candidate paths here
    # (only beam_width at the first step, when there is a single beam). Why?
    print(len(new_beams))

    # Check for EOS.
    # If a path ends with the end-of-sequence token, move it to the completed list.
    # Thanks to wangrujia for debugging this.
    remaining_beams = []
    for beam_input_ids, beam_score in new_beams:
        if beam_input_ids[0, -1].item() == tokenizer.eos_token_id:
            print('reached EOS')
            completed_beams.append((beam_input_ids, beam_score))
        else:
            remaining_beams.append((beam_input_ids, beam_score))
    new_beams = remaining_beams
    
    # Keep the beam_width highest-scoring beams
    beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
    print(len(beams))
    print(beams)

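    # Stop once enough paths have completed. Several candidates can finish in the
    # same step, so compare with >= rather than ==.
    if len(completed_beams) >= beam_width:
        break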
        
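# Add any unfinished beams to the completed list.
# The final number of paths is typically at least beam_width and is often larger.
# Why are there extra paths?
# When a path hits EOS it stops being extended, but the number of live beams
# is still refilled to beam_width at the next step.
print('Completed paths:', len(completed_beams))
completed_beams.extend(beams)
print('Final number of candidate paths:', len(completed_beams))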
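
The comments in the loop above motivate summing log-probabilities instead of multiplying probabilities. The sketch below checks that both criteria rank paths the same way, and shows why the log form is preferred numerically. The per-step probabilities are hypothetical values chosen for illustration, and `joint_prob` and `log_prob_sum` are illustrative helper names.

```python
import math

# Hypothetical per-step conditional probabilities for two paths
path_a = [0.6, 0.5, 0.4]
path_b = [0.9, 0.2, 0.5]

def joint_prob(probs):
    """Product of the conditional probabilities along a path."""
    return math.prod(probs)

def log_prob_sum(probs):
    """Sum of the log-probabilities along a path."""
    return sum(math.log(p) for p in probs)

# log is monotonic, so both criteria select the same path
best_by_product = max([path_a, path_b], key=joint_prob)
best_by_log_sum = max([path_a, path_b], key=log_prob_sum)
print(best_by_product is best_by_log_sum)

# A long product of small probabilities underflows to 0.0,
# while the sum of logs stays finite
tiny = [1e-5] * 100
print(joint_prob(tiny), log_prob_sum(tiny))
```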

In [3]:
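# Each entry is an (input_ids, score) pair; input_ids has shape (1, seq_len)
for beam_input_ids, beam_score in completed_beams:
    generated_text = tokenizer.decode(beam_input_ids[0], skip_special_tokens=True)
    print(generated_text)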

# Beam search path selection

In [4]:
x = [
    ["a", 0.04],
    ["b", 0.7],
    ["d", 0.2],
    ["e", 0.06],
]
# Sort by score in descending order and keep the top 2
result = sorted(x, key=lambda item: item[1], reverse=True)[:2]
print(result)
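
The same top-k selection can be written with `heapq.nlargest` from the standard library, which avoids sorting the full candidate list. This is an alternative sketch, not what the notebook uses; `candidates` and `top2` are names introduced here for illustration.

```python
import heapq

# Same (token, score) pairs as in the cell above
candidates = [("a", 0.04), ("b", 0.7), ("d", 0.2), ("e", 0.06)]

# Keep the 2 highest-scoring candidates, highest first
top2 = heapq.nlargest(2, candidates, key=lambda item: item[1])
print(top2)
```

For a beam this small the difference is negligible; it matters more when the candidate list is much larger than the number of paths kept.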

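# Follow-up questions

What is the time complexity of beam search relative to greedy decoding?

Does beam search find the globally optimal decoding path?

How many candidate paths does beam search consider?

If a path hits EOS early, does the number of beams used in the following decoding steps change?

If some paths hit EOS early, the final candidates differ in length. Is the sum of log-probabilities still a reasonable ranking metric?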
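
As a starting point for the last question: every extra token adds a negative log-probability, so a raw sum tends to favor shorter paths. One common remedy is to divide the score by the path length before ranking. The sketch below assumes made-up scores, and the exponent `alpha` is a hypothetical knob introduced here, not something defined in this notebook.

```python
# Hypothetical finished paths: (tokens, summed log-probability)
finished = [
    (["Once", "upon", "<eos>"], -2.4),
    (["Once", "upon", "a", "time", "there", "<eos>"], -3.6),
]

def length_normalized(score, length, alpha=1.0):
    """Divide the summed log-probability by length**alpha (alpha=0 disables normalization)."""
    return score / (length ** alpha)

# Raw sum prefers the short path; the per-token average prefers the long one
best_by_sum = max(finished, key=lambda c: c[1])
best_by_norm = max(finished, key=lambda c: length_normalized(c[1], len(c[0])))
print(len(best_by_sum[0]), len(best_by_norm[0]))
```

Here the short path wins on the raw sum (-2.4 versus -3.6), while the longer path wins on the per-token average (-0.6 versus -0.8).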