In [17]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import OneHotEncoder, LabelEncoder

In [18]:
# 定义网络
class lstm_model(nn.Module):
    def __init__(self, vocab, hidden_size, num_layers, dropout=0.5):
        super(lstm_model, self).__init__()
        self.vocab = vocab 
        self.int_char = {i : char for i, char in enumerate(vocab)} 
        self.char_int = {char : i for i, char in self.int_char.items()}
        self.encoder = OneHotEncoder(sparse=False).fit(vocab.reshape(-1, 1)) 
        self.hidden_size = hidden_size
        self.num_layers = num_layers
    
        # lstm层
        self.lstm = nn.LSTM(len(vocab), hidden_size, num_layers, batch_first=True, dropout=dropout)
        
        # 全连接层
        self.linear = nn.Linear(hidden_size, len(vocab)) # 这里的输出shape是每个字符的得分
        
    def forward(self, sequence, hs=None):
        out, hs = self.lstm(sequence, hs) # lstm的输出格式：（batch_size, sequence_length, hidden_size）
        out = out.reshape(-1, self.hidden_size) # 这里需要将out转换为linear的输入格式，即(batch_size*sequence_length, hidden_size)
        output = self.linear(out) # linear的输出格式：((batch_size*sequence_length, vocab_size)
        return output, hs
        
    def onehot_encode(self, data):
        return self.encoder.transform(data)
    
    def onehot_decode(self, data):
        return self.encoder.inverse_transform(data)
    
    def label_encode(self, data):
        return np.array([self.char_int[ch] for ch in data])
    
    def label_decode(self, data):
        return np.array([self.int_char[ch] for ch in data])

In [19]:
# 定义构建新数据集的批处理方法
def get_batches(data, batch_size, seq_len):

    num_features = data.shape[1] 
    num_chars = batch_size * seq_len 
    num_batches = int(np.floor(len(data) / num_chars)) 
    need_chars = num_batches * num_chars
    targets = np.append(data[1:], data[0]).reshape(data.shape) 
    inputs = data[:need_chars] 
    targets = targets[:need_chars] 
    
    # shape转换
    inputs = inputs.reshape(batch_size, -1, num_features)
    targets = targets.reshape(batch_size, -1, num_features)
    
    # 构建新的数据集
    for i in range(0, inputs.shape[1], seq_len):
        x = inputs[:, i : i+seq_len]
        y = targets[:, i : i+seq_len]
        yield x, y


In [20]:
# 定义训练函数
def train(model, data, batch_size, seq_len, epochs, lr=0.01, valid=None):

    # 是否有cuda
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    if valid is not None:
        data = model.onehot_encode(data.reshape(-1, 1))
        valid = model.onehot_encode(valid.reshape(-1, 1))
    else:
        data = model.onehot_encode(data.reshape(-1, 1))
    # 保存损失值
    train_loss = []
    val_loss = []
    # 循环训练（验证）
    for epoch in range(epochs):
        model.train()
        hs = None # hs 等于 hidden_size,隐藏层结点
        train_ls = 0.0
        val_ls = 0.0
        for x, y in get_batches(data, batch_size, seq_len):
             # 梯度置零
            optimizer.zero_grad()
            x = torch.tensor(x).float().to(device) # 类型转换
            # 模型训练
            out, hs = model(x, hs)
            hs = ([h.data for h in hs]) 
            y = y.reshape(-1, len(model.vocab))
            y = model.onehot_decode(y)
            y = model.label_encode(y.squeeze())
            y = torch.from_numpy(y).long().to(device)
            loss = criterion(out, y.squeeze())
            loss.backward()
            optimizer.step()
            train_ls += loss.item()
        
        if valid is not None:
            # 开始验证
            model.eval()
            hs = None
            with torch.no_grad():
                for x, y in get_batches(valid, batch_size, seq_len):
                    x = torch.tensor(x).float().to(device)
                    out, hs = model(x, hs) # 预测输出
                    hs = ([h.data for h in hs])
                    
                    y = y.reshape(-1, len(model.vocab))
                    y = model.onehot_decode(y)
                    y = model.label_encode(y.squeeze())
                    y = torch.from_numpy(y).long().to(device)
                    
                    loss = criterion(out, y.squeeze())
                    val_ls += loss.item()
                    
                val_loss.append(np.mean(val_ls)) # 求出每一轮的损失均值，并累计
                
            train_loss.append(np.mean(train_ls)) # 求出每一轮的损失均值，并累计
            
        print(f'--------------Epochs{epochs} | {epoch}---------------')
        print(f'Train Loss : {train_loss[-1]}') # 这里-1为最后添加进去的loss值，即本轮batch的loss
        if val_loss:
            print(f'Val Loss : {val_loss[-1]}')
            
    # 绘制loss曲线
    plt.plot(train_loss, label='Train Loss')
    plt.plot(val_loss, label='Val Loss')
    plt.title('Loss vs Epochs')
    plt.legend()
    plt.show()

In [21]:
def throw_trash(string: str) -> str:
    reg: re.Pattern = re.compile('^[0-9a-zA-Z_]{1,}$')
    string: str = reg.sub(' ', string)
    string = re.sub( ' +',' ', string)
    string = string.replace('\n','')
    string=string.replace("\'","'")
    return string.lower()

In [22]:
with open("D:\\downloads\\anna.txt") as data:
    text = data.read()
text=throw_trash(text)
text[:100]

'chapter 1happy families are all alike; every unhappy family is unhappy in its ownway.everything was '

In [23]:
vocab = np.array(sorted(set(text)))
vocab

array([' ', '!', '"', '$', '%', '&', "'", '(', ')', '*', ',', '-', '.',
       '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';',
       '?', '@', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
       'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
       'w', 'x', 'y', 'z'], dtype='<U1')

In [24]:
# 字符的数量
vocab_size = len(vocab)

In [25]:
val_len = int(np.floor(0.2 * len(text)))

In [26]:
trainset = np.array(list(text[:-val_len]))
validset = np.array(list(text[-val_len:]))
print(trainset.shape)
print(validset.shape)

(1555604,)
(388901,)


In [28]:
hidden_size = 512
num_layers = 2
batch_size = 128
seq_len = 100
epochs = 50
lr = 0.01
model = lstm_model(vocab, hidden_size, num_layers)
model

lstm_model(
  (lstm): LSTM(56, 512, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=512, out_features=56, bias=True)
)

In [None]:
train(model, trainset, batch_size, seq_len, epochs, lr=lr, valid=validset)

--------------Epochs50 | 0---------------
Train Loss : 362.3254086971283
Val Loss : 84.53495693206787
--------------Epochs50 | 1---------------
Train Loss : 320.5443913936615
Val Loss : 72.57386445999146
--------------Epochs50 | 2---------------
Train Loss : 276.7701494693756
Val Loss : 63.997851610183716
--------------Epochs50 | 3---------------
Train Loss : 246.4478178024292
Val Loss : 56.10816693305969
--------------Epochs50 | 4---------------
Train Loss : 220.44631016254425
Val Loss : 51.02923321723938
--------------Epochs50 | 5---------------
Train Loss : 203.19574356079102
Val Loss : 47.72372257709503
--------------Epochs50 | 6---------------
Train Loss : 191.03313851356506
Val Loss : 45.1748366355896
--------------Epochs50 | 7---------------
Train Loss : 181.56514525413513
Val Loss : 43.388057589530945
--------------Epochs50 | 8---------------
Train Loss : 174.6627584695816
Val Loss : 42.18847990036011
--------------Epochs50 | 9---------------
Train Loss : 169.4619710445404
Val 

### 模型预测

In [13]:
def predict(model, char, top_k = None, hidden_size = None):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    with torch.no_grad():
        char = np.array([char]) 
        char = char.reshape(-1, 1) 
        char_encoding = model.onehot_encode(char) 
        char_encoding = char_encoding.reshape(1, 1, -1) 
        char_tensor = torch.tensor(char_encoding, dtype=torch.float32) 
        char_tensor = char_tensor.to(device) 
        out, hidden_size = model(char_tensor, hidden_size) 
        probs = F.softmax(out, dim=1).squeeze() 

        if top_k is None:
            indices = np.arange(vocab_size)
        else:
            probs, indices = probs.topk(top_k) # 选取概率最大的前top_k个
            indices = indices.cpu().numpy()
        
        probs = probs.cpu().numpy()
        
        char_index = np.random.choice(indices, p = probs / probs.sum()) # 随机选取一个索引
        char = model.int_char[char_index] # 获取索引对应的字符
        
    return char, hidden_size

In [14]:
# 获取一个样本
def sample(model, length, top_k = None, sentence="every unhappy family "):
    hidden_size = None # 初始化
    new_sentence = [char for char in sentence] # 初始化
    for i in range(length):
        next_char, hidden_size = predict(model, new_sentence[-1], top_k = top_k, hidden_size = hidden_size) # 预测下一个字符
        new_sentence.append(next_char)
        
    return ''.join(new_sentence)

In [15]:
new_text = sample(model, 2000, top_k=5)

In [16]:
new_text

'every unhappy family mieiehthed to hte hem mind to hess anded, hast serteng, indes and inded hy thet and hy the sas inter tt andeshing.."thans, sateds how hisse to hesing hom hit has, sond her the thised then ard at an hind, sout tas ander hamilg his as tagd and this he woutt is, shout thoner the sint asditilg hond thanens hered has to handsing tor hid tom she went hus i fit shet hes hed ang tan tilid the thes he med ald tinding hed it han sound tin hime hads he her had the sit if ind sas hes, ander him to mas sor henders anded, hid hed i matthingsinged hud tom thensen his site sere at ham, ind the ming as the gadled. sestileds tad the paddsen sased he wamed, has tomligtinged he mat tas ang tissend serer hes he he sand ind and her, as thithing,s he sinds, a gase hed if sas ill the song hid ind an she mime ham thing has,,."ing illering sis ind, hes shethed indesse that her silling hom sishine ham sin hat the sises sithen hid his sering has, in he sis herese whate her hem mong, the hemp

### 保存模型

In [43]:
model_name = "lstm_model.net"

checkpoint = {
    'hidden_size' : model.hidden_size,
    'num_layers' : model.num_layers,
    'state_dict' : model.state_dict()
}

with open(model_name, 'wb') as f:
    torch.save(checkpoint, f)