In [3]:
import torch
from torch import nn

In [4]:
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

    NOTE(review): the LSTM is created without ``batch_first``, so (per the
    PyTorch default) it reads dim 0 of its input as the time axis and dim 1 as
    the batch axis. ``forward`` feeds the embedding of ``x`` through without
    transposing, so callers must size the state from ``init_state`` to match
    dim 1 of ``x``. Confirm this layout is what the callers intend.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        """Build the layers.

        The embedding layer turns word indices into word vectors; the LSTM
        cells implement the gating mechanism that lets the network learn from
        long sequences.

        Args:
            dataset: object exposing ``uniq_words``; its length is the vocabulary size.
            lstm_size: hidden size of every LSTM layer.
            embedding_dim: size of each word vector.
            num_layers: number of stacked LSTM layers.
            dropout: dropout value handed to ``nn.LSTM``.
        """
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        # The LSTM consumes the embedding output, so its input width has to be
        # embedding_dim, not lstm_size (the two only coincide under the defaults).
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one step of the model.

        The recurrent state is kept outside the model and has to be passed in
        by hand through ``prev_state``.

        Args:
            x: integer tensor of word indices.
            prev_state: ``(h, c)`` tuple as produced by ``init_state`` or by a
                previous call to ``forward``.

        Returns:
            ``(logits, state)`` where ``logits`` has one score per vocabulary
            word in its last dimension and ``state`` is the new ``(h, c)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` state; called at the start of each epoch.

        Both tensors have shape ``(num_layers, sequence_length, lstm_size)``.
        They are allocated with the dtype and device of the model's own
        weights so the state can be fed to ``forward`` wherever the model lives.
        """
        shape = (self.num_layers, sequence_length, self.lstm_size)
        reference = self.fc.weight
        return (reference.new_zeros(shape), reference.new_zeros(shape))

In [21]:
import torch
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    """Sliding-window word dataset built from the ``Joke`` column of a CSV file.

    ``args`` must expose ``sequence_length`` (the window size). It may also
    expose ``data_path``; when absent, ``DEFAULT_DATA_PATH`` is read.
    """

    # Used when ``args`` carries no ``data_path`` attribute.
    DEFAULT_DATA_PATH = 'data/reddit-cleanjokes.txt'

    def __init__(self, args):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Numeric index -> word, and the inverse word -> numeric index.
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}
        # The whole corpus as one flat list of word indices.
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Load the corpus and return it as a list of space-separated tokens."""
        data_path = getattr(self.args, 'data_path', self.DEFAULT_DATA_PATH)
        train_df = pd.read_csv(data_path)
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens, most frequent first.

        The length of this list is what defines the vocabulary size of the network.
        """
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of (input, target) windows; 0 when the corpus is too short."""
        # Clamp at zero: a negative value would make len() raise ValueError.
        return max(0, len(self.words_indexes) - self.args.sequence_length)

    def __getitem__(self, index):
        """Return the ``index``-th window and the same window shifted one word right.

        Raises:
            IndexError: if ``index`` is outside ``range(len(self))``. This also
                lets plain iteration over the dataset terminate.
        """
        if not 0 <= index < len(self):
            raise IndexError('Dataset index out of range')
        window = self.args.sequence_length
        return (
            torch.tensor(self.words_indexes[index:index + window]),
            torch.tensor(self.words_indexes[index + 1:index + window + 1]),
        )

In [22]:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

#from model import Model
#from dataset import Dataset

def train(dataset, model, args):
    """Fit ``model`` on ``dataset`` and print the last batch's loss after each epoch.

    Args:
        dataset: source of ``(x, y)`` pairs handed to ``DataLoader``.
        model: module exposing ``init_state(sequence_length)`` and a
            ``forward(x, (h, c))`` that returns ``(logits, (h, c))``.
        args: namespace with ``batch_size``, ``max_epochs`` and ``sequence_length``.

    Raises:
        ValueError: if an epoch yields no batches, since there is nothing to
            train on or to report.

    NOTE(review): the recurrent state is sized with ``args.sequence_length``
    and is therefore independent of the batch size. This matches a model whose
    LSTM is built without ``batch_first``. Confirm that layout is intended.
    """
    model.train()
    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()                    # cross-entropy as the loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam; parameters can be tuned later

    for epoch in range(args.max_epochs):
        state_h, state_c = model.init_state(args.sequence_length)
        # Reset per epoch so an empty loader is detected instead of hitting
        # unbound (or stale) `batch` / `loss` names in the report below.
        batch, loss = None, None
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss wants the class dimension second.
            loss = criterion(y_pred.transpose(1, 2), y)
            # Cut the graph so the next batch does not backpropagate into this one.
            state_h = state_h.detach()
            state_c = state_c.detach()
            loss.backward()
            optimizer.step()
        if loss is None:
            raise ValueError('dataloader produced no batches; nothing to train on')
        print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` words from the model.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word`` mappings.
        model: module exposing ``eval()``, ``init_state(n)`` and a call that
            returns ``(logits, (h, c))``.
        text: seed prompt; split on single spaces. Every word must be a key of
            ``dataset.word_to_index`` (a missing word raises ``KeyError``).
        next_words: how many words to generate.

    Returns:
        The seed words followed by the generated words, as a list of strings.
    """
    model.eval()
    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))
    # Inference only: without no_grad the autograd graph would chain through
    # the carried state and keep growing with every generated word.
    with torch.no_grad():
        for i in range(next_words):
            # The window always has the seed's length: it slides right by one
            # each time a word is appended.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
            # Sample rather than take the argmax.
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])
    return words

In [25]:
# Hyperparameters are declared as command-line options so the same code can
# be lifted into a script later on.
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', default=10, type=int)
parser.add_argument('--batch-size', default=256, type=int)
parser.add_argument('--sequence-length', default=4, type=int)

# Parse an explicit empty argument list so the kernel's own argv is ignored
# and every option takes its default.
args = parser.parse_args([])
args

Namespace(batch_size=256, max_epochs=10, sequence_length=4)

In [26]:
# Build the corpus, size the model from its vocabulary, then fit it.
dataset = Dataset(args=args)
model = Model(dataset=dataset)
train(dataset=dataset, model=model, args=args)

# Sample a continuation of the seed prompt and show the resulting word list.
print(predict(dataset=dataset, model=model, text='Knock knock. Whos there?'))

{'epoch': 0, 'batch': 0, 'loss': 8.850239753723145}
{'epoch': 0, 'batch': 1, 'loss': 8.843761444091797}
{'epoch': 0, 'batch': 2, 'loss': 8.834165573120117}
{'epoch': 0, 'batch': 3, 'loss': 8.833130836486816}
{'epoch': 0, 'batch': 4, 'loss': 8.820727348327637}
{'epoch': 0, 'batch': 5, 'loss': 8.815651893615723}
{'epoch': 0, 'batch': 6, 'loss': 8.806750297546387}
{'epoch': 0, 'batch': 7, 'loss': 8.787812232971191}
{'epoch': 0, 'batch': 8, 'loss': 8.773186683654785}
{'epoch': 0, 'batch': 9, 'loss': 8.714690208435059}
{'epoch': 0, 'batch': 10, 'loss': 8.652864456176758}
{'epoch': 0, 'batch': 11, 'loss': 8.541131019592285}
{'epoch': 0, 'batch': 12, 'loss': 8.435351371765137}
{'epoch': 0, 'batch': 13, 'loss': 8.34483528137207}
{'epoch': 0, 'batch': 14, 'loss': 8.06029224395752}
{'epoch': 0, 'batch': 15, 'loss': 7.983303070068359}
{'epoch': 0, 'batch': 16, 'loss': 7.809556484222412}
{'epoch': 0, 'batch': 17, 'loss': 7.741303443908691}
{'epoch': 0, 'batch': 18, 'loss': 7.622403621673584}
{'epo

{'epoch': 1, 'batch': 62, 'loss': 7.177131175994873}
{'epoch': 1, 'batch': 63, 'loss': 7.1053080558776855}
{'epoch': 1, 'batch': 64, 'loss': 7.232308864593506}
{'epoch': 1, 'batch': 65, 'loss': 7.141412258148193}
{'epoch': 1, 'batch': 66, 'loss': 7.135429382324219}
{'epoch': 1, 'batch': 67, 'loss': 6.949068546295166}
{'epoch': 1, 'batch': 68, 'loss': 7.16571044921875}
{'epoch': 1, 'batch': 69, 'loss': 6.892696380615234}
{'epoch': 1, 'batch': 70, 'loss': 7.320590019226074}
{'epoch': 1, 'batch': 71, 'loss': 7.262444496154785}
{'epoch': 1, 'batch': 72, 'loss': 7.1685075759887695}
{'epoch': 1, 'batch': 73, 'loss': 7.242525100708008}
{'epoch': 1, 'batch': 74, 'loss': 7.234708786010742}
{'epoch': 1, 'batch': 75, 'loss': 7.406533241271973}
{'epoch': 1, 'batch': 76, 'loss': 7.160241603851318}
{'epoch': 1, 'batch': 77, 'loss': 7.420930862426758}
{'epoch': 1, 'batch': 78, 'loss': 7.541903495788574}
{'epoch': 1, 'batch': 79, 'loss': 6.826304912567139}
{'epoch': 1, 'batch': 80, 'loss': 7.123613357

{'epoch': 3, 'batch': 29, 'loss': 7.289906024932861}
{'epoch': 3, 'batch': 30, 'loss': 6.63802433013916}
{'epoch': 3, 'batch': 31, 'loss': 6.5678510665893555}
{'epoch': 3, 'batch': 32, 'loss': 6.679880619049072}
{'epoch': 3, 'batch': 33, 'loss': 6.9062113761901855}
{'epoch': 3, 'batch': 34, 'loss': 6.866612434387207}
{'epoch': 3, 'batch': 35, 'loss': 7.114994525909424}
{'epoch': 3, 'batch': 36, 'loss': 7.0179057121276855}
{'epoch': 3, 'batch': 37, 'loss': 6.811435699462891}
{'epoch': 3, 'batch': 38, 'loss': 7.181065559387207}
{'epoch': 3, 'batch': 39, 'loss': 6.9506120681762695}
{'epoch': 3, 'batch': 40, 'loss': 7.140921115875244}
{'epoch': 3, 'batch': 41, 'loss': 6.86493444442749}
{'epoch': 3, 'batch': 42, 'loss': 7.160517692565918}
{'epoch': 3, 'batch': 43, 'loss': 6.846338748931885}
{'epoch': 3, 'batch': 44, 'loss': 6.828035831451416}
{'epoch': 3, 'batch': 45, 'loss': 6.880589962005615}
{'epoch': 3, 'batch': 46, 'loss': 7.079382419586182}
{'epoch': 3, 'batch': 47, 'loss': 7.39375114

{'epoch': 4, 'batch': 90, 'loss': 7.1789045333862305}
{'epoch': 4, 'batch': 91, 'loss': 6.663756847381592}
{'epoch': 4, 'batch': 92, 'loss': 6.90627908706665}
{'epoch': 4, 'batch': 93, 'loss': 6.308578014373779}
{'epoch': 5, 'batch': 0, 'loss': 6.72037935256958}
{'epoch': 5, 'batch': 1, 'loss': 6.693321228027344}
{'epoch': 5, 'batch': 2, 'loss': 6.6880598068237305}
{'epoch': 5, 'batch': 3, 'loss': 6.847273349761963}
{'epoch': 5, 'batch': 4, 'loss': 6.791945934295654}
{'epoch': 5, 'batch': 5, 'loss': 6.775998115539551}
{'epoch': 5, 'batch': 6, 'loss': 7.2741193771362305}
{'epoch': 5, 'batch': 7, 'loss': 7.049893856048584}
{'epoch': 5, 'batch': 8, 'loss': 6.967711448669434}
{'epoch': 5, 'batch': 9, 'loss': 6.923521041870117}
{'epoch': 5, 'batch': 10, 'loss': 6.942782402038574}
{'epoch': 5, 'batch': 11, 'loss': 6.775883197784424}
{'epoch': 5, 'batch': 12, 'loss': 6.908381462097168}
{'epoch': 5, 'batch': 13, 'loss': 7.0361433029174805}
{'epoch': 5, 'batch': 14, 'loss': 6.6259236335754395}


{'epoch': 6, 'batch': 57, 'loss': 6.413529396057129}
{'epoch': 6, 'batch': 58, 'loss': 6.349175453186035}
{'epoch': 6, 'batch': 59, 'loss': 6.505659580230713}
{'epoch': 6, 'batch': 60, 'loss': 6.370521068572998}
{'epoch': 6, 'batch': 61, 'loss': 6.575816631317139}
{'epoch': 6, 'batch': 62, 'loss': 6.540684223175049}
{'epoch': 6, 'batch': 63, 'loss': 6.4311652183532715}
{'epoch': 6, 'batch': 64, 'loss': 6.424099445343018}
{'epoch': 6, 'batch': 65, 'loss': 6.505342483520508}
{'epoch': 6, 'batch': 66, 'loss': 6.533333778381348}
{'epoch': 6, 'batch': 67, 'loss': 6.198822975158691}
{'epoch': 6, 'batch': 68, 'loss': 6.510982513427734}
{'epoch': 6, 'batch': 69, 'loss': 6.132898807525635}
{'epoch': 6, 'batch': 70, 'loss': 6.740647315979004}
{'epoch': 6, 'batch': 71, 'loss': 6.549978733062744}
{'epoch': 6, 'batch': 72, 'loss': 6.473031997680664}
{'epoch': 6, 'batch': 73, 'loss': 6.511375904083252}
{'epoch': 6, 'batch': 74, 'loss': 6.614479064941406}
{'epoch': 6, 'batch': 75, 'loss': 6.596007347

{'epoch': 8, 'batch': 24, 'loss': 6.434555530548096}
{'epoch': 8, 'batch': 25, 'loss': 6.261373996734619}
{'epoch': 8, 'batch': 26, 'loss': 5.987137794494629}
{'epoch': 8, 'batch': 27, 'loss': 6.058148384094238}
{'epoch': 8, 'batch': 28, 'loss': 6.557715892791748}
{'epoch': 8, 'batch': 29, 'loss': 6.676220893859863}
{'epoch': 8, 'batch': 30, 'loss': 5.8996405601501465}
{'epoch': 8, 'batch': 31, 'loss': 5.84095573425293}
{'epoch': 8, 'batch': 32, 'loss': 5.995049953460693}
{'epoch': 8, 'batch': 33, 'loss': 6.271839618682861}
{'epoch': 8, 'batch': 34, 'loss': 6.194241523742676}
{'epoch': 8, 'batch': 35, 'loss': 6.301849365234375}
{'epoch': 8, 'batch': 36, 'loss': 6.279446125030518}
{'epoch': 8, 'batch': 37, 'loss': 6.126622676849365}
{'epoch': 8, 'batch': 38, 'loss': 6.554077625274658}
{'epoch': 8, 'batch': 39, 'loss': 6.258509635925293}
{'epoch': 8, 'batch': 40, 'loss': 6.479681491851807}
{'epoch': 8, 'batch': 41, 'loss': 6.125746250152588}
{'epoch': 8, 'batch': 42, 'loss': 6.5323100090

{'epoch': 9, 'batch': 85, 'loss': 6.159114837646484}
{'epoch': 9, 'batch': 86, 'loss': 5.893980979919434}
{'epoch': 9, 'batch': 87, 'loss': 6.057595252990723}
{'epoch': 9, 'batch': 88, 'loss': 5.88210916519165}
{'epoch': 9, 'batch': 89, 'loss': 6.007853984832764}
{'epoch': 9, 'batch': 90, 'loss': 6.564356803894043}
{'epoch': 9, 'batch': 91, 'loss': 5.873315811157227}
{'epoch': 9, 'batch': 92, 'loss': 6.146152496337891}
{'epoch': 9, 'batch': 93, 'loss': 5.500251293182373}
['Knock', 'knock.', 'Whos', 'there?', 'finally', 'Demetri', 'a', 'yam!"', 'A', '--From', 'Because', "I'm", 'a', 'blind', 'My', 'water', 'I', 'today?"', '"I', 'a', 'blind,', '"Our', 'suckers', 'A', 'hair?', 'leek', 'Refresh', 'boats', 'A', 'byte', 'A', 'into', 'clean', 'a', '"Hey!', 'strip...', 'the', 'woman.', 'How', 'did', 'the', 'happened', "don't", 'security', 'beef', 'and', 'was', 'house?', 'favorite', 'scales', 'team?', 'of', "I'm", 'dying', 'his', 'been', 'What', 'you', 'destroy', 'Superman', 'in', 'pepper', 'tir