呱呱

In [5]:
import collections
import random
import requests

# 下载 wiki_text.txt
def download_wiki_text(file_path):
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    response = requests.get(url)
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

# 读取数据集
def load_text(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read().lower()  # 统一转换为小写，减少重复字符影响
    return text

# 统计 n-gram 频率
def build_ngram_counts(text, n):
    ngram_counts = collections.defaultdict(collections.Counter)
    
    for i in range(len(text) - n):
        prefix = tuple(text[i:i+n-1])  # 前 n-1 个字符
        next_char = text[i+n-1]        # 预测的目标字符
        ngram_counts[prefix][next_char] += 1
    
    return ngram_counts

# 统计不同 n-gram 的数量
def count_unique_ngrams(ngram_counts):
    return len(ngram_counts)

# 找出出现最多的 n-gram
def most_frequent_ngram(ngram_counts):
    return max(ngram_counts.items(), key=lambda x: sum(x[1].values()))

# 生成文本
def generate_text(ngram_counts, start_seq, length=100):
    generated_text = list(start_seq)
    
    for _ in range(length):
        prefix = tuple(generated_text[-(len(start_seq)):])
        if prefix in ngram_counts:
            next_char = random.choices(
                list(ngram_counts[prefix].keys()), 
                weights=ngram_counts[prefix].values()
            )[0]
            generated_text.append(next_char)
        else:
            break  # 若遇到未知的前缀，则停止生成
    
    return ''.join(generated_text)

# 主程序
file_path = "wiki_text.txt"
download_wiki_text(file_path)
text = load_text(file_path)

for n in [2, 3]:
    print(f"\n=== {n}-gram 统计信息 ===")
    ngram_counts = build_ngram_counts(text, n)
    print(f"不同的 {n}-gram 组合数量: {count_unique_ngrams(ngram_counts)}")
    
    most_frequent, occurrences = most_frequent_ngram(ngram_counts)
    print(f"最常出现的 {n}-gram: {most_frequent}，总计 {sum(occurrences.values())} 次")
    print(f"最可能出现的下一个字符: {occurrences.most_common(3)}")
    
    # 从最频繁的 n-gram 生成文本
    generated_text = generate_text(ngram_counts, most_frequent, length=200)
    print(f"生成文本: {generated_text}")



=== 2-gram 统计信息 ===
不同的 2-gram 组合数量: 39
最常出现的 2-gram: (' ',)，总计 169892 次
最可能出现的下一个字符: [('t', 24243), ('a', 13939), ('s', 12733)]
生成文本:  wamfar larit tl ltrthy,
winch tom f illyotheybrth n l ary lls:
llat inongormay!
i wies 'l iour, it intull,

ould! f cher noutmater g,
thounche ce bsayothas feata byedsoe
lla veft:
swoocke by, f
h is s

=== 3-gram 统计信息 ===
不同的 3-gram 组合数量: 832
最常出现的 3-gram: ('e', ' ')，总计 27965 次
最可能出现的下一个字符: [('t', 3640), ('s', 2122), ('a', 2085)]
生成文本: e youll rat, sixt--ad eyer, gin vosty 's sour, wicke en thell cou to nexcus,
thatis'd;
fie?

houtly inged unt dee i'll nobeedward:
youch hisay, sout i'll.
workmanow loore; le;
gene nothospare my noblent


part 3

In [16]:
import collections
import random
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


# 定义数据集类
class TextDataset(Dataset):
    def __init__(self, text, n, char_to_idx):
        self.n = n
        self.char_to_idx = char_to_idx
        self.data = []
        self.labels = []
        for i in range(len(text) - n):
            self.data.append([char_to_idx[c] for c in text[i:i+n-1]])
            self.labels.append(char_to_idx[text[i+n-1]])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return torch.tensor(self.data[idx]), torch.tensor(self.labels[idx])
    

# 定义 1D-CNN 模型
class CNNTextGenerator(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_filters, kernel_size, hidden_dim):
        super(CNNTextGenerator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv1d = nn.Conv1d(embed_dim, num_filters, kernel_size, padding=1)
        self.fc1 = nn.Linear(num_filters * (n-1), hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x).permute(0, 2, 1)  # 调整维度适应 Conv1d
        x = torch.relu(self.conv1d(x)).flatten(start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
# 训练模型
def train_model(model, dataloader, criterion, optimizer, epochs):
    for epoch in range(epochs):
        total_loss = 0
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")


# 生成文本
def generate_cnn_text(model, start_seq, idx_to_char, char_to_idx, length=100, temperature=1.0):
    model.eval()
    generated_text = list(start_seq)

    for _ in range(length):
        input_seq = torch.tensor([[char_to_idx[c] for c in generated_text[-(n-1):]]], device=device)
        with torch.no_grad():
            output = model(input_seq)
            output = output.squeeze(0) / temperature  # 调整温度参数
            probabilities = torch.softmax(output, dim=0)
            next_char_idx = torch.multinomial(probabilities, num_samples=1).item()  # 采样
            generated_text.append(idx_to_char[next_char_idx])

    return ''.join(generated_text)



# 主程序
file_path = "shakespeare.txt"
download_wiki_text(file_path)
text = load_text(file_path)


# 构建字符索引
chars = sorted(set(text))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for c, i in char_to_idx.items()}
vocab_size = len(chars)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n = 3  # 选择 n-gram 的值

dataset = TextDataset(text, n, char_to_idx)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 初始化模型
model = CNNTextGenerator(vocab_size, embed_dim=32, num_filters=64, kernel_size=3, hidden_dim=128).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
train_model(model, dataloader, criterion, optimizer, epochs=10)

# 生成文本
start_seq = "th"
generated_text = generate_cnn_text(model, start_seq, idx_to_char, char_to_idx, length=200)
print("\n生成文本:", generated_text)

start_seq = "or"
generated_text = generate_cnn_text(model, start_seq, idx_to_char, char_to_idx, length=200)
print("\n生成文本:", generated_text)

start_seq = "ar"
generated_text = generate_cnn_text(model, start_seq, idx_to_char, char_to_idx, length=200)
print("\n生成文本:", generated_text)


Epoch 1, Loss: 2.0629205393468344
Epoch 2, Loss: 2.005397035963688
Epoch 3, Loss: 1.9958393552009464
Epoch 4, Loss: 1.9905648772421578
Epoch 5, Loss: 1.9873847461961567
Epoch 6, Loss: 1.9851459673930123
Epoch 7, Loss: 1.9832827137141482
Epoch 8, Loss: 1.9818782862485318
Epoch 9, Loss: 1.981221349025391
Epoch 10, Loss: 1.9799752683849654

生成文本: ther youredmus:
anightem;
areme
of of houg to mentrathrome
thend forciell rand that gaorm
my crovospet--and strese
thfuld lon,
burets the raw or pet,
wel the ser spard, as of spe oare your leadairst ge 

生成文本: ory: grew-oxshan of as wart nius:
cary ard i tend lit the tar to prall by a gives you th shark, brantelf to bood ne ind
thall'd cult seloves, to to horrathomad theit surs we itiesbuy theredwely to sover

生成文本: ard: tain, ber'd:
shourat the ame, reed parris ta:
isay, uponot dinsworticit ord:
haris prot he laiv:
i wer lif not the frowbat mang remaduke you fi, ch he taray, hattlevess but.

ve kin of hich do her'
