In [1]:
import os

# Point both proxy environment variables at a proxy listening on localhost port 7890.
# NOTE(review): this address is machine-specific — presumably needed to reach the
# model hub from the author's network; adjust or remove where no such proxy runs.
os.environ.update({
    'http_proxy': 'http://127.0.0.1:7890',
    'https_proxy': 'http://127.0.0.1:7890',
})

In [2]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pretrained 'gpt2' tokenizer and language-model head.
# The two loads are independent of each other.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')



In [3]:
# Encode the initial prompt as a PyTorch tensor of token ids.
prompt = "Hello, my name is"
input_ids = tokenizer.encode(prompt, return_tensors='pt')

# Leave the tensor as the cell's last expression so the notebook displays it.
input_ids

tensor([[15496,    11,   616,  1438,   318]])

#### 第一步生成

```
past_key_values = (
    (key_layer_1, value_layer_1),
    (key_layer_2, value_layer_2),
    ...
    (key_layer_N, value_layer_N)
)

# 每个 key_layer_i / value_layer_i 张量的形状：
(batch_size, num_heads, seq_length, head_dim)
```


In [4]:
# First generation step: run the whole prompt through the model once and
# ask it to return its key/value cache.
# This is inference only, so autograd is switched off; otherwise the logits
# and the cached keys/values would keep the forward pass's graph alive.
with torch.no_grad():
    output = model(input_ids, use_cache=True)
next_token_logits = output.logits[:, -1, :]  # logits at the last time step
past_key_values = output.past_key_values    # cached keys and values

In [6]:
# Show the logits' dimensions; the last axis matches the config's vocab_size.
output.logits.size()

torch.Size([1, 5, 50257])

In [10]:
# Display the model configuration. Fields to compare against the tensor shapes:
#   "vocab_size": 50257
#   "n_layer": 12
#   "n_head": 12
model.config

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.45.0.dev0",
  "use_cache": true,
  "vocab_size": 50257
}

In [12]:
# Number of cache entries, and the shape of the first entry's first tensor
# (the first layer's key, per the structure sketched earlier in the notebook).
first_layer_key = past_key_values[0][0]
(len(past_key_values), first_layer_key.shape)

(12, torch.Size([1, 12, 5, 64]))

#### 采样下一个 token

In [None]:

# Choose the next token (here: the one with the highest logit). keepdim=True
# keeps a trailing axis of size 1, the same result as argmax followed by
# unsqueeze(-1), so the token can be passed straight back in as model input.
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

# Second generation step: feed only the new token together with the cache
# from the previous step.
# This is inference only, so autograd is switched off; otherwise every cached
# step would keep its computation graph alive and memory would keep growing.
with torch.no_grad():
    output = model(next_token, past_key_values=past_key_values, use_cache=True)
next_token_logits = output.logits[:, -1, :]
past_key_values = output.past_key_values

# Repeat the steps above until generation ends.
