In [1]:
def read_data():
    with open('./the_time_machine.txt', 'r') as txt:
        lines = txt.readlines()
    import re
    return [
        l for l in
        [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
        if l.strip() != ''
    ]


In [2]:
lines = read_data()

In [3]:
len(lines)

3093

In [4]:
def tokenize(lines, token='word'):
    return [list(line) if token == 'char' else line.split() for line in lines]


In [5]:
class Vocab:

    def __init__(self, tokens=[], min_freq=0, reserved_tokens=[]):
        counter = Vocab.count_corpus(tokens)
        # 对词频率排序
        self.__token_freqs = sorted(counter.items(),
                                    key=lambda x: x[1],
                                    reverse=True)
        self.index_to_token = ['<unk>'] + reserved_tokens
        self.token_to_index = {
            token: idx
            for idx, token in enumerate(self.index_to_token)
        }
        for token, freq in self.__token_freqs:
            if freq >= min_freq and token not in self.token_to_index:
                self.index_to_token.append(token)
                self.token_to_index[token] = len(self.index_to_token) - 1

    def __len__(self):
        return len(self.index_to_token)

    def get_tokens(self, indicates):
        if not isinstance(indicates, (list, tuple)):
            return self.index_to_token[indicates]
        return ''.join([self.get_tokens(index) for index in indicates])
    

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_index.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self.__token_freqs

    @staticmethod
    def count_corpus(tokens):
        if isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        from collections import Counter
        return Counter(tokens)

In [6]:
tokens = tokenize(lines)
vocab = Vocab(tokens)

In [7]:
vocab.token_freqs

[('the', 2477),
 ('and', 1312),
 ('of', 1286),
 ('i', 1268),
 ('a', 877),
 ('to', 766),
 ('in', 606),
 ('was', 554),
 ('that', 458),
 ('it', 452),
 ('my', 441),
 ('had', 354),
 ('as', 281),
 ('me', 281),
 ('with', 264),
 ('at', 257),
 ('for', 247),
 ('you', 212),
 ('time', 211),
 ('but', 209),
 ('this', 199),
 ('or', 162),
 ('were', 158),
 ('on', 148),
 ('not', 142),
 ('from', 137),
 ('all', 136),
 ('then', 134),
 ('is', 129),
 ('have', 129),
 ('his', 129),
 ('there', 128),
 ('by', 126),
 ('he', 126),
 ('they', 124),
 ('one', 120),
 ('upon', 115),
 ('so', 114),
 ('into', 114),
 ('little', 114),
 ('be', 112),
 ('came', 107),
 ('no', 102),
 ('gutenberg', 98),
 ('some', 95),
 ('machine', 93),
 ('could', 93),
 ('an', 92),
 ('which', 92),
 ('we', 91),
 ('their', 91),
 ('said', 89),
 ('project', 88),
 ('saw', 88),
 ('down', 87),
 ('s', 86),
 ('very', 86),
 ('them', 86),
 ('now', 79),
 ('what', 78),
 ('these', 77),
 ('about', 77),
 ('any', 75),
 ('been', 75),
 ('her', 75),
 ('up', 74),
 ('out

In [8]:
tokens[666], vocab[tokens[666]]

(['that',
  'i',
  'noticed',
  'for',
  'the',
  'first',
  'time',
  'how',
  'warm',
  'the',
  'air',
  'was'],
 [9, 4, 518, 17, 1, 98, 19, 104, 698, 1, 199, 8])

In [9]:
import tensorflow as tf


def load_corpus():
    lines = read_data()
    tokens = tokenize(lines, 'char')
    vocab = Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    return corpus, vocab


def seq_data_iter_random(corpus, batch_size, num_steps):  #@save
    """使用随机抽样生成一个小批量子序列"""
    # 从随机偏移量开始对序列进行分区，随机范围包括num_steps-1
    import random
    corpus = corpus[random.randint(0, num_steps - 1):]
    # 减去1，是因为我们需要考虑标签
    num_subseqs = (len(corpus) - 1) // num_steps
    # 长度为num_steps的子序列的起始索引
    initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
    # 在随机抽样的迭代过程中，
    # 来自两个相邻的、随机的、小批量中的子序列不一定在原始序列上相邻
    random.shuffle(initial_indices)

    def data(pos):
        # 返回从pos位置开始的长度为num_steps的序列
        return corpus[pos:pos + num_steps]

    num_batches = num_subseqs // batch_size
    for i in range(0, batch_size * num_batches, batch_size):
        # 在这里，initial_indices包含子序列的随机起始索引
        initial_indices_per_batch = initial_indices[i:i + batch_size]
        X = [data(j) for j in initial_indices_per_batch]
        Y = [data(j + 1) for j in initial_indices_per_batch]
        yield tf.constant(X), tf.constant(Y)


In [10]:
corpus, vocab = load_corpus()
BATCH_SIZE = 32
NUM_STEPS = 35
train_iter = seq_data_iter_random(corpus=corpus,
                                  batch_size=BATCH_SIZE,
                                  num_steps=NUM_STEPS)


In [12]:
hiddens = 256


class RNN(tf.keras.layers.Layer):

    def __init__(self, vocab_size, hiddens, **kwargs):
        super(RNN, self).__init__(**kwargs)
        self.rnn_cell = tf.keras.layers.SimpleRNNCell(
            hiddens, kernel_initializer='glorot_uniform')
        self.rnn_layer = tf.keras.layers.RNN(self.rnn_cell,
                                             time_major=True,
                                             return_sequences=True,
                                             return_state=True)
        self.vocab_size = vocab_size
        self.dense = tf.keras.layers.Dense(vocab_size)
    
    def call(self, x, state):
        x = tf.one_hot(tf.transpose(x), self.vocab_size)
        y, *state = self.rnn_layer(x, state)
        y = self.dense(tf.reshape(y, (-1, y.shape[-1])))
        return y, state

    def predict(self, x, n_pred, vocab: Vocab):
        pred_state = self.init_state(batch_size=1, dtype='float32')
        y = [vocab[x[0]]]
        # 先预热state，其实就是记录x的隐藏状态
        for y_ in x[1:]:
            _, pred_state = self.__call__(
                tf.reshape(tf.constant(y[-1]), (1, 1)).numpy(), pred_state)
            y.append(vocab[y_])
        # 开始利用x的隐藏状态进行预测
        for _ in range(n_pred):
            y_, pred_state = self.__call__(
                tf.reshape(tf.constant(y[-1]), (1, 1)).numpy(), pred_state)
            y.append(y_.numpy().argmax(axis=1)[0])
        print(y)
        return ''.join([vocab.index_to_token[c] for c in y])
    
    def gradient_clip(self, grads, theta):
        """
            梯度裁剪，防止梯度爆炸问题
        """
        theta = tf.constant(theta, dtype="float32")
        new_grads = []
        for grad in grads:
            new_grads.append(
                tf.convert_to_tensor(grad) if isinstance(
                    grad, tf.IndexedSlices) else grad)
        # L2范数
        norm = tf.math.sqrt(
            sum((tf.reduce_sum(grad**2).numpy() for grad in new_grads)))
        norm = tf.cast(norm, "float32")
        if tf.greater(norm, theta):
            for i, grad in enumerate(new_grads):
                new_grads[i] = grad * theta / norm
        return new_grads

    def fit(self, train_iter, epochs=10, lr=1e-3):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
        for epoch in range(epochs):
            for x, y in train_iter():
                state = self.init_state(batch_size=x.shape[0], dtype='float32')
                with tf.GradientTape(persistent=True) as gt:
                    # 向前计算
                    y_hat, state = self.__call__(x, state)
                    y = tf.reshape(tf.transpose(y), (-1, 1))
                    # 损失计算
                    l = loss_fn(y, y_hat)
                grads = gt.gradient(l, self.trainable_variables)
                grads = self.gradient_clip(grads, 1)
                # print(tf.reduce_sum(grads[0]))
                optimizer.apply_gradients(zip(grads, self.trainable_variables))
            print("epoch[%d\%d loss: %.6f" % (epoch + 1, epochs, l))
    
    def init_state(self, *args, **kwargs):
        return self.rnn_cell.get_initial_state(*args, **kwargs)

net = RNN(len(vocab), 512)
def get_train_iter():
    return seq_data_iter_random(corpus=corpus,
                                  batch_size=BATCH_SIZE,
                                  num_steps=NUM_STEPS)
net.fit(train_iter=get_train_iter, epochs=200, lr=1e-3)
net.predict('hello world', 50, vocab)

epoch[1\200 loss: 3.190975
epoch[2\200 loss: 2.975364
epoch[3\200 loss: 2.937062
epoch[4\200 loss: 2.949771
epoch[5\200 loss: 2.939920
epoch[6\200 loss: 2.886942
epoch[7\200 loss: 2.872908
epoch[8\200 loss: 2.878511
epoch[9\200 loss: 2.858157
epoch[10\200 loss: 2.866439
epoch[11\200 loss: 2.836256
epoch[12\200 loss: 2.859394
epoch[13\200 loss: 2.846761
epoch[14\200 loss: 2.825873
epoch[15\200 loss: 2.844437
epoch[16\200 loss: 2.808188
epoch[17\200 loss: 2.812666
epoch[18\200 loss: 2.813853
epoch[19\200 loss: 2.848081
epoch[20\200 loss: 2.788270
epoch[21\200 loss: 2.790801
epoch[22\200 loss: 2.804196
epoch[23\200 loss: 2.778377
epoch[24\200 loss: 2.799523
epoch[25\200 loss: 2.803141
epoch[26\200 loss: 2.781680
epoch[27\200 loss: 2.747699
epoch[28\200 loss: 2.774217
epoch[29\200 loss: 2.752084
epoch[30\200 loss: 2.784050
epoch[31\200 loss: 2.811021
epoch[32\200 loss: 2.744595
epoch[33\200 loss: 2.732336
epoch[34\200 loss: 2.754540
epoch[35\200 loss: 2.745453
epoch[36\200 loss: 2.742124
e

'hello world the the the the the the the sare the the the the '