In [1]:
from IPython.display import Image

- https://huggingface.co/blog/how-to-generate
- https://huggingface.co/blog/constrained-beam-search

### search

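- greedy search => beam search
    - greedy search: at each step, keep only the token with the highest logit (top-1)
        - `[batch_size, seq_length]`, where `seq_length` grows by one at every step
    - beam search: keep several candidates per step; the number kept is the beam width
        - `[batch_size * num_beams, seq_length]`, where `seq_length` again grows by one at every step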
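
The difference between the two strategies can be sketched on a tiny hand-written probability table. This is only an illustration: the `NEXT` table, its probabilities, and the `search` helper are hypothetical and are not part of `transformers`. With a beam width of 1 the search commits to the locally best token; with a beam width of 2 it can recover a sequence whose first token looked worse but whose overall probability is higher.

```python
import math

# Hypothetical next-token probabilities, keyed by the sequence so far.
NEXT = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("The", "dog"): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("The", "car"): {"is": 0.3, "drives": 0.5, "turns": 0.2},
}

def search(prefix, num_beams, steps):
    """Keep the num_beams best sequences at every step (num_beams=1 is greedy)."""
    beams = [(tuple(prefix), 0.0)]  # (sequence, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            for token, p in NEXT[seq].items():
                candidates.append((seq + (token,), score + math.log(p)))
        # Prune: sort by score and keep only the best num_beams candidates.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams

greedy = search(["The"], num_beams=1, steps=2)[0]
beam = search(["The"], num_beams=2, steps=2)[0]
print(" ".join(greedy[0]), round(math.exp(greedy[1]), 2))  # The nice woman 0.2
print(" ".join(beam[0]), round(math.exp(beam[1]), 2))      # The dog has 0.36
```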

In [2]:
# beam width = 1
# The diagram expands every branch for illustration; in practice, "dog" is never expanded at step 2
Image(url='https://huggingface.co/blog/assets/02_how-to-generate/greedy_search.png', width=400)
# 1. The
# 2. The nice
# 3. The nice woman

In [15]:
# beam width = 2
Image(url='https://huggingface.co/blog/assets/02_how-to-generate/beam_search.png', width=400)
# 1. The
# 2. The nice
# 2. The dog
# 3. The nice woman
# 3. The dog has

### generate with beam search

- `model(input_ids)`: a single forward pass, i.e. one decoding step;
- `model.generate(input_ids)`: many steps, i.e. autoregressive generation;
    - `max_length`: prompt length + `max_new_tokens`
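
As a rough sketch of how the single step relates to the generation loop, the toy code below replaces the model with a made-up rule. `toy_step` and `toy_generate` are hypothetical names, not library functions, and the loop ignores EOS handling. It only shows how the length budget works: a total `max_length` and a `max_new_tokens` budget describe the same limit once the prompt length is known.

```python
def toy_step(ids):
    """Stand-in for one forward pass: returns a next token id (made-up rule)."""
    return (ids[-1] + 1) % 10

def toy_generate(prompt_ids, max_length=None, max_new_tokens=None):
    """Greedy autoregressive loop that stops once the total length hits the limit."""
    ids = list(prompt_ids)
    if max_new_tokens is not None:
        # A budget of new tokens is converted to a total-length limit.
        max_length = len(ids) + max_new_tokens
    while len(ids) < max_length:
        ids.append(toy_step(ids))  # one step per loop iteration
    return ids

prompt = [3, 4, 5, 6]                             # a 4-token prompt
by_total = toy_generate(prompt, max_length=20)    # total length capped at 20
by_new = toy_generate(prompt, max_new_tokens=16)  # 4 + 16 = 20
print(len(by_total), len(by_new))  # 20 20
```

With the 4-token prompts used in this notebook, `max_length=20` therefore leaves room for 16 generated tokens.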

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

prefixes = ["Once upon a time", "Hi I am a"]
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer(prefixes, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, num_beams=3, max_length=20)

output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
for text in output_text:
    print(text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


input_ids: torch.Size([6, 4])
tensor([[ 7454,  2402,   257,   640],
        [ 7454,  2402,   257,   640],
        [ 7454,  2402,   257,   640],
        [17250,   314,   716,   257],
        [17250,   314,   716,   257],
        [17250,   314,   716,   257]])
input_ids: torch.Size([6, 5])
tensor([[ 7454,  2402,   257,   640,    11],
        [ 7454,  2402,   257,   640,   262],
        [ 7454,  2402,   257,   640,   314],
        [17250,   314,   716,   257,  1263],
        [17250,   314,   716,   257,   845],
        [17250,   314,   716,   257,  1310]])
input_ids: torch.Size([6, 6])
tensor([[ 7454,  2402,   257,   640,    11,   262],
        [ 7454,  2402,   257,   640,    11,   314],
        [ 7454,  2402,   257,   640,    11,   340],
        [17250,   314,   716,   257,  1263,  4336],
        [17250,   314,   716,   257,  1310,  1643],
        [17250,   314,   716,   257,   845,   922]])
input_ids: torch.Size([6, 7])
tensor([[ 7454,  2402,   257,   640,    11,   340,   373],
        

In [4]:
input_ids

tensor([[ 7454,  2402,   257,   640],
        [17250,   314,   716,   257]])

In [5]:
output_ids.shape

torch.Size([2, 20])
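
The intermediate `input_ids` dumps have 6 rows while `output_ids` has only 2. A rough sketch of that bookkeeping, using plain Python lists instead of tensors: each prompt is repeated `num_beams` times before decoding, and by default `generate` returns one sequence per prompt. The helper names are hypothetical; the token ids and scores are copied from the first-step outputs in this notebook. The real implementation also handles EOS tokens and length penalties, which this sketch leaves out.

```python
def expand_for_beams(batch, num_beams):
    """Repeat each row num_beams times: [batch, seq] -> [batch * num_beams, seq]."""
    return [list(row) for row in batch for _ in range(num_beams)]

def best_per_prompt(rows, scores, num_beams):
    """Pick the highest-scoring row from each consecutive group of num_beams rows."""
    best = []
    for start in range(0, len(rows), num_beams):
        group = range(start, start + num_beams)
        best.append(rows[max(group, key=lambda i: scores[i])])
    return best

batch = [[7454, 2402, 257, 640], [17250, 314, 716, 257]]
expanded = expand_for_beams(batch, num_beams=3)
print(len(expanded), len(expanded[0]))  # 6 4

# One decoding step: every beam row receives its own next token and score.
next_tokens = [11, 262, 314, 1263, 845, 1310]
scores = [-0.8512, -2.7396, -3.2029, -3.8471, -4.0766, -4.1127]
rows = [row + [tok] for row, tok in zip(expanded, next_tokens)]

best = best_per_prompt(rows, scores, num_beams=3)
print(len(best), len(best[0]))  # 2 5
```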

In [18]:
greedy_output = model.generate(input_ids, max_length=20)
greedy_output

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[ 7454,  2402,   257,   640,    11,   262,   995,   373,   257,  1295,
           286,  1049,  8737,   290,  1049,  3514,    13,   383,   995,   373],
        [17250,   314,   716,   257,  1263,  4336,   286,   262,   649,   366,
            47,  9990, 32767,     1, 14256,    13,   314,   423,   587,   284]])

In [17]:
greedy_output = model.generate(input_ids, max_length=20, num_beams=1)
greedy_output

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[ 7454,  2402,   257,   640,    11,   262,   995,   373,   257,  1295,
           286,  1049,  8737,   290,  1049,  3514,    13,   383,   995,   373],
        [17250,   314,   716,   257,  1263,  4336,   286,   262,   649,   366,
            47,  9990, 32767,     1, 14256,    13,   314,   423,   587,   284]])

### step by step

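- A beam's score is the sum of its token log-probabilities, which is equivalent to multiplying the probabilities: $\log p_1+\log p_2=\log (p_1\cdot p_2)$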
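
A small numeric check of this identity, and of why scores are usually kept in log space. The probabilities below are made up for illustration. Multiplying many small probabilities underflows to zero in floating point, whereas the sum of their logs stays well within range.

```python
import math

probs = [0.5, 0.4, 0.9]  # hypothetical per-token probabilities
log_sum = sum(math.log(p) for p in probs)
product = math.prod(probs)

# Summing logs gives the log of the product.
print(math.isclose(log_sum, math.log(product)))  # True

# 100 tokens with probability 1e-5 each: the product underflows, the log sum does not.
tiny = [1e-5] * 100
tiny_product = math.prod(tiny)
tiny_log_sum = sum(math.log(p) for p in tiny)
print(tiny_product)            # 0.0
print(round(tiny_log_sum, 2))  # -1151.29
```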

In [7]:
import torch
import torch.nn.functional as F

In [9]:
def show_beam_search_steps(model, tokenizer, prefix, num_beams=3, max_steps=3):
    # Convert the input text to token ids
    input_ids = tokenizer(prefix, return_tensors="pt").input_ids
    
    # Initialise the beam state
    current_beams = [(input_ids, 0)]  # (sequence, cumulative log-probability)
    
    print(f"\nProcessing prefix: '{prefix}'")
    
    # Run one round of beam search per step
    for step in range(max_steps):
        candidates = []
        print(f"\nStep {step + 1}:")
        
        # Expand every current beam
        for beam_ids, beam_score in current_beams:
            # Run the model and take the distribution over the next token
            with torch.no_grad():
                outputs = model(beam_ids)
                next_token_logits = outputs.logits[:, -1, :]
                next_token_probs = F.softmax(next_token_logits, dim=-1)
            
            # Take the num_beams most likely next tokens
            values, indices = torch.topk(next_token_probs, num_beams)
            
            # Create a new candidate for each possible next token
            for value, index in zip(values[0], indices[0]):
                new_ids = torch.cat([beam_ids, index.unsqueeze(0).unsqueeze(0)], dim=1)
                new_score = beam_score + torch.log(value).item()
                candidates.append((new_ids, new_score))
                
                # Print the current candidate
                new_text = tokenizer.decode(new_ids[0])
                print(f"Candidate: {new_text}({new_ids[0].tolist()}) score: {new_score:.4f}")
        
        # Keep the num_beams best candidates
        candidates.sort(key=lambda x: x[1], reverse=True)
        current_beams = candidates[:num_beams]
        print("\nSelected beams:")
        for beam_ids, beam_score in current_beams:
            print(f"beam: {tokenizer.decode(beam_ids[0])}({beam_ids[0].tolist()}) score: {beam_score:.4f}")

In [11]:
show_beam_search_steps(model, tokenizer, prefixes[0])


Processing prefix: 'Once upon a time'

Step 1:
Candidate: Once upon a time,([7454, 2402, 257, 640, 11]) score: -0.8512
Candidate: Once upon a time the([7454, 2402, 257, 640, 262]) score: -2.7396
Candidate: Once upon a time I([7454, 2402, 257, 640, 314]) score: -3.2029

Selected beams:
beam: Once upon a time,([7454, 2402, 257, 640, 11]) score: -0.8512
beam: Once upon a time the([7454, 2402, 257, 640, 262]) score: -2.7396
beam: Once upon a time I([7454, 2402, 257, 640, 314]) score: -3.2029

Step 2:
Candidate: Once upon a time, the([7454, 2402, 257, 640, 11, 262]) score: -3.0524
Candidate: Once upon a time, I([7454, 2402, 257, 640, 11, 314]) score: -3.6055
Candidate: Once upon a time, it([7454, 2402, 257, 640, 11, 340]) score: -4.0718
Candidate: Once upon a time the world([7454, 2402, 257, 640, 262, 995]) score: -6.5612
Candidate: Once upon a time the sun([7454, 2402, 257, 640, 262, 4252]) score: -7.6559
Candidate: Once upon a time the people([7454, 2402, 257, 640, 262, 661]) score: -7.7589
Candidate: Once upon a time I was([7454, 2402, 257, 640, 314, 373]) score: -4.8048
Candidate: Once upon a time I had([74

In [12]:
show_beam_search_steps(model, tokenizer, prefixes[1])


Processing prefix: 'Hi I am a'

Step 1:
Candidate: Hi I am a big([17250, 314, 716, 257, 1263]) score: -3.8471
Candidate: Hi I am a very([17250, 314, 716, 257, 845]) score: -4.0766
Candidate: Hi I am a little([17250, 314, 716, 257, 1310]) score: -4.1127

Selected beams:
beam: Hi I am a big([17250, 314, 716, 257, 1263]) score: -3.8471
beam: Hi I am a very([17250, 314, 716, 257, 845]) score: -4.0766
beam: Hi I am a little([17250, 314, 716, 257, 1310]) score: -4.1127

Step 2:
Candidate: Hi I am a big fan([17250, 314, 716, 257, 1263, 4336]) score: -4.2283
Candidate: Hi I am a big believer([17250, 314, 716, 257, 1263, 29546]) score: -7.1364
Candidate: Hi I am a big supporter([17250, 314, 716, 257, 1263, 15525]) score: -8.3071
Candidate: Hi I am a very good([17250, 314, 716, 257, 845, 922]) score: -6.7408
Candidate: Hi I am a very nice([17250, 314, 716, 257, 845, 3621]) score: -7.1981
Candidate: Hi I am a very happy([17250, 314, 716, 257, 845, 3772]) score: -7.3774
Candidate: Hi I am a little bit([17250, 314, 716, 257, 1310, 1643]) score: -6.2787
Candidate: Hi I am a little confused([17250, 314, 716, 257, 1310, 1