In [8]:
import torch
from torch import nn
import tiktoken
from GPT模型架构 import GPT

In [9]:
# Hyperparameters handed to GPT(...) when the model is built further down.
GPT_CONFIG_124M = {
    # vocabulary size
    "vocab_size": 50257,
    # context length (also used as the generation context window below)
    "context_length": 1024,
    # embedding dimension
    "emb_dim": 768,
    # number of attention heads
    "n_heads": 12,
    # number of layers
    "n_layers": 12,
    # dropout rate
    "drop_rate": 0.1,
    # query-key-value bias
    "qkv_bias": False,
}

In [4]:
def generate_text_simple(model, idx, max_text, context_size):
    """Greedily append ``max_text`` new token ids to ``idx``.

    At every step the model is called on at most the last ``context_size``
    tokens, the scores at the final position are taken, and the highest
    scoring token id is appended to the running sequence.

    Args:
        model: Callable that accepts a tensor of token ids and returns
            scores that can be indexed as ``[:, -1, :]``.
        idx: Tensor of token ids whose last dimension is the sequence
            (it is sliced as ``idx[:, -context_size:]``).
        max_text: Number of tokens to generate.
        context_size: Maximum number of trailing tokens fed to the model
            per step.

    Returns:
        A tensor holding the full original ``idx`` followed by the
        ``max_text`` generated token ids along the last dimension.
    """
    for _ in range(max_text):
        # Crop only the *model input*. Re-binding ``idx`` here would drop the
        # prompt / earlier tokens from the returned sequence as soon as it
        # grows past ``context_size``.
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Keep only the scores for the last position.
        last_logits = logits[:, -1, :]
        # Softmax is monotonic, so argmax over the raw logits selects the same
        # token without the extra normalisation pass.
        idx_next = torch.argmax(last_logits, dim=-1, keepdim=True)
        idx = torch.cat((idx, idx_next), dim=-1)
    return idx

In [6]:
# Encode the prompt with the "gpt2" tiktoken encoding.
tokenizer = tiktoken.get_encoding("gpt2")
start_context = "Hello, I am"
encoded = tokenizer.encode(start_context)
print("encoded:", encoded)

# Wrap the id list in an outer list so the tensor gets a leading dimension
# of size 1 (one row holding the whole prompt).
encoded_tensor = torch.tensor([encoded])
print("encoded_tensor.shape:", encoded_tensor.shape)

encoded: [15496, 11, 314, 716]
encoded_tensor.shape: torch.Size([1, 4])


In [14]:
# Seed PyTorch's global RNG before building the model. No weights are loaded
# in this cell, so any random initialisation inside GPT(...) would otherwise
# differ on every kernel restart and the generated tokens would not be
# reproducible.
torch.manual_seed(123)

model = GPT(GPT_CONFIG_124M)
# Inference mode: for a torch.nn.Module this turns off dropout
# (presumably GPT is an nn.Module -- confirm in the GPT模型架构 module).
model.eval()

# Greedily generate 6 new tokens after the encoded prompt.
out = generate_text_simple(
    model=model,
    idx=encoded_tensor,
    max_text=6,
    context_size=GPT_CONFIG_124M["context_length"],
)
print("Output:", out)
print("Output length:", len(out[0]))

Output: tensor([[15496,    11,   314,   716, 25879, 37941, 40879, 21185, 17790, 21271]])
Output length: 10


In [15]:
# Drop the leading size-1 dimension, convert the ids to a plain Python list,
# and turn them back into text with the same tokenizer used for encoding.
decoded_text = tokenizer.decode(torch.squeeze(out, 0).tolist())
print(decoded_text)

Hello, I am regimes ie bribe censorship combinationsHAHA
