In [6]:
import requests
import re
import os 

# Download the corpus only once; later runs reuse the cached local copy.
if not os.path.exists('shakespeare.txt'):
    url = 'https://www.gutenberg.org/files/100/100-0.txt'
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors rather than caching an error page as the corpus.
    response.raise_for_status()

    # Decode with 'utf-8-sig' so the UTF-8 BOM is stripped; without this the
    # saved text contains stray garbled characters.
    text = response.content.decode('utf-8-sig')

    with open('shakespeare.txt', 'w', encoding='utf-8') as file:
        file.write(text)

    print("下载完成并保存为 shakespeare.txt")

In [7]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
from transformer import Transformer

# 1. Set up the GPT-2 tokenizer (any other tokenizer could be substituted here).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. Read the whole corpus into memory and encode it as one long list of token ids.
with open('shakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
tokens = tokenizer.encode(text)

# 3. Build the dataset
class ShakespeareDataset(Dataset):
    """Fixed-length, non-overlapping windows over a flat sequence of token ids.

    Args:
        tokens: Sequence of integer token ids (e.g. the list returned by
            ``tokenizer.encode``).
        seq_length: Number of tokens per sample; must be positive.

    Trailing tokens that do not fill a complete window are dropped.
    """

    def __init__(self, tokens, seq_length):
        if seq_length <= 0:
            raise ValueError(f"seq_length must be positive, got {seq_length}")
        self.tokens = tokens
        self.seq_length = seq_length

    def __len__(self):
        """Return the number of complete windows available."""
        return len(self.tokens) // self.seq_length

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D ``torch.long`` tensor of ``seq_length`` ids.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_windows = len(self)
        if idx < 0:
            # Support Python-style negative indexing.
            idx += num_windows
        if not 0 <= idx < num_windows:
            # Raising IndexError is what lets plain `for sample in dataset`
            # terminate; slicing past the end would silently yield short or
            # empty tensors forever.
            raise IndexError(f"index {idx} out of range for {num_windows} samples")
        start = idx * self.seq_length
        end = start + self.seq_length
        return torch.tensor(self.tokens[start:end], dtype=torch.long)


Token indices sequence length is longer than the specified maximum sequence length for this model (1686447 > 1024). Running this sequence through the model will result in indexing errors


In [8]:
from torch.utils.data import random_split

# 4. Build the dataset
# NOTE: this re-declares the ShakespeareDataset class from the previous cell;
# the definition below is the one in effect for the rest of the notebook.
class ShakespeareDataset(Dataset):
    """Fixed-length, non-overlapping windows over a flat sequence of token ids.

    Args:
        tokens: Sequence of integer token ids (e.g. the list returned by
            ``tokenizer.encode``).
        seq_length: Number of tokens per sample; must be positive.

    Trailing tokens that do not fill a complete window are dropped.
    """

    def __init__(self, tokens, seq_length):
        if seq_length <= 0:
            raise ValueError(f"seq_length must be positive, got {seq_length}")
        self.tokens = tokens
        self.seq_length = seq_length

    def __len__(self):
        """Return the number of complete windows available."""
        return len(self.tokens) // self.seq_length

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D ``torch.long`` tensor of ``seq_length`` ids.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_windows = len(self)
        if idx < 0:
            # Support Python-style negative indexing.
            idx += num_windows
        if not 0 <= idx < num_windows:
            # Raising IndexError is what lets plain `for sample in dataset`
            # terminate; slicing past the end would silently yield short or
            # empty tensors forever.
            raise IndexError(f"index {idx} out of range for {num_windows} samples")
        start = idx * self.seq_length
        end = start + self.seq_length
        return torch.tensor(self.tokens[start:end], dtype=torch.long)

# 5. Split into training and validation sets
seq_length = 128
dataset = ShakespeareDataset(tokens, seq_length)
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # remaining ~20% for validation

# Seed the split so the same samples land in the validation set on every
# kernel restart; an unseeded random_split reshuffles each run.
# NOTE(review): the checkpoint loaded later was trained outside this notebook —
# confirm this seed reproduces the split used for training, otherwise some
# "validation" samples may have been seen during training.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8)

In [9]:
# 6. Build the model and restore the trained weights
# (no loss function or optimizer is created here — this cell is inference-only).
model = Transformer(
    vocab_size=tokenizer.vocab_size,
    d_model=512,
    num_heads=8,
    num_layers=2,
    d_ff=2048,
    max_seq_len=seq_length
)

# Load the checkpoint.
# - map_location='cpu' lets the load succeed on machines without a GPU even if
#   the tensors were saved from a CUDA device.
# - weights_only=True restricts unpickling to tensors/containers, so a tampered
#   checkpoint file cannot execute arbitrary code.
# NOTE(review): the path is an absolute, environment-specific location; adjust
# it when running elsewhere.
state_dict = torch.load(
    '/teamspace/studios/this_studio/artifacts/best_model_val_loss=0.0780.pth',
    map_location='cpu',
    weights_only=True,
)
model.load_state_dict(state_dict)

# Switch to evaluation mode (the call also returns the model, so the notebook
# displays its structure as the cell output).
model.eval()

Transformer(
  (preprocessor): TransformerPreprocessor(
    (embedding): Embedding(50257, 512)
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (attention): MultiHeadAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForward(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-1): 2 x DecoderLayer(
       

In [10]:
import random

# Draw one validation sample uniformly at random.
random_index = random.randint(0, len(val_dataset) - 1)
sample = val_dataset[random_index]

# Build the model inputs, each with a leading batch dimension of 1:
# src drops the last token, tgt drops the first (i.e. src shifted by one).
# NOTE(review): tgt is fed to the model AND is the text the prediction is
# compared against below — confirm this matches how the model was trained.
src, tgt = sample[:-1].unsqueeze(0), sample[1:].unsqueeze(0)

# Build the masks, assuming padding uses token id 0 on both sides.
# NOTE(review): GPT-2 defines no pad token and id 0 appears to be a regular
# vocabulary entry — verify this assumption against the training code.
src_pad_idx = tgt_pad_idx = 0
src_mask = model.make_src_mask(src, src_pad_idx)
tgt_mask = model.make_trg_mask(tgt, tgt_pad_idx)

# Forward pass without gradient tracking, then take the highest-scoring
# token at every position.
with torch.no_grad():
    output = model(src, tgt, src_mask, tgt_mask)
predicted_tokens = output.argmax(dim=-1)

# Decode the three token sequences back into text.
input_text, target_text, predicted_text = (
    tokenizer.decode(token_ids.squeeze().tolist())
    for token_ids in (src, tgt, predicted_tokens)
)

# Show input, target and prediction, one per line.
print(
    f"输入文本: {input_text}",
    f"目标文本: {target_text}",
    f"预测文本: {predicted_text}",
    sep="\n",
)

输入文本:  makes a still-stand, running neither way.
Fain would I go to meet the Archbishop,
But many thousand reasons hold me back.
I will resolve for Scotland. There am I,
Till time and vantage crave my company.

 [_Exeunt._]

SCENE IV. London. The Boar’s head Tavern in Eastcheap.

Enter two Drawers.

FIRST DRAWER.
What the devil hast thou brought there—applejohns? Thou knowest Sir
John cannot endure an applejohn.

SECOND DRAWER
目标文本:  a still-stand, running neither way.
Fain would I go to meet the Archbishop,
But many thousand reasons hold me back.
I will resolve for Scotland. There am I,
Till time and vantage crave my company.

 [_Exeunt._]

SCENE IV. London. The Boar’s head Tavern in Eastcheap.

Enter two Drawers.

FIRST DRAWER.
What the devil hast thou brought there—applejohns? Thou knowest Sir
John cannot endure an applejohn.

SECOND DRAWER.
预测文本:  a still-stand, running neither way.
Fain would I go to meet the Archbishop,
But many thousand reasons hold me back.
I will resolve for S